mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added hf models import tab and route for getting available hf models
This commit is contained in:
parent
efea1a8a7d
commit
f7cd3cf1f4
@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
from starlette.exceptions import HTTPException
|
||||
from typing_extensions import Annotated
|
||||
|
||||
@ -29,6 +29,7 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -246,6 +247,29 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/hugging_face",
|
||||
operation_id="get_hugging_face_models",
|
||||
responses={
|
||||
200: {"description": "Hugging Face repo scanned successfully"},
|
||||
400: {"description": "Invalid hugging face repo"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[AnyHttpUrl],
|
||||
)
|
||||
async def get_hugging_face_models(
|
||||
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
||||
) -> List[AnyHttpUrl]:
|
||||
get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models
|
||||
get_hugging_face_models(hugging_face_repo)
|
||||
|
||||
result = get_hugging_face_models(
|
||||
source=hugging_face_repo,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
"/i/{key}",
|
||||
operation_id="update_model_record",
|
||||
|
@ -20,6 +20,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
|
||||
from invokeai.backend.model_manager.config import ModelSourceType
|
||||
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile
|
||||
|
||||
|
||||
class InstallStatus(str, Enum):
|
||||
@ -405,6 +406,19 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_hugging_face_models(
|
||||
self,
|
||||
source: str,
|
||||
) -> List[AnyHttpUrl]:
|
||||
"""Get the available models in a HuggingFace repo.
|
||||
|
||||
:param source: HuggingFace repo string
|
||||
|
||||
This will get the urls for the available models in the indicated,
|
||||
repo, and return them as a list of AnyHttpUrl strings.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
@ -233,6 +233,22 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
access_token = HfFolder.get_token()
|
||||
if not access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
|
||||
self._logger.info(f"metadata is {metadata}")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
remote_files = metadata.download_urls(
|
||||
session=self._session,
|
||||
)
|
||||
|
||||
# return array of remote_files.url
|
||||
return [x.url for x in remote_files]
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
|
@ -51,7 +51,15 @@ def filter_files(
|
||||
(
|
||||
"learned_embeds.bin",
|
||||
"ip_adapter.bin",
|
||||
"lora_weights.safetensors",
|
||||
# jennifer added a bunch of these, probably will break something
|
||||
".safetensors",
|
||||
".bin",
|
||||
".onnx",
|
||||
".xml",
|
||||
".pth",
|
||||
".pt",
|
||||
".ckpt",
|
||||
".msgpack",
|
||||
"weights.pb",
|
||||
"onnx_data",
|
||||
)
|
||||
@ -71,7 +79,8 @@ def filter_files(
|
||||
paths = [x for x in paths if x.parent == Path(subfolder)]
|
||||
|
||||
# _filter_by_variant uniquifies the paths and returns a set
|
||||
return sorted(_filter_by_variant(paths, variant))
|
||||
# jennifer removed the filter since it removed models, probably will break something but i dont understand why it removes valid models :|
|
||||
return sorted(paths)
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -762,6 +762,8 @@
|
||||
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above.",
|
||||
"height": "Height",
|
||||
"heightValidationMsg": "Default height of your model.",
|
||||
"huggingFace": "Hugging Face",
|
||||
"huggingFaceRepoID": "HuggingFace Repo ID",
|
||||
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
|
||||
"imageEncoderModelId": "Image Encoder Model ID",
|
||||
"importModels": "Import Models",
|
||||
|
@ -0,0 +1,49 @@
|
||||
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
|
||||
|
||||
import { HuggingFaceResults } from './HuggingFaceResults';
|
||||
|
||||
export const HuggingFaceForm = () => {
|
||||
const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
|
||||
const [errorMessage, setErrorMessage] = useState('');
|
||||
const { t } = useTranslation();
|
||||
|
||||
const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();
|
||||
|
||||
const scanFolder = useCallback(async () => {
|
||||
_getHuggingFaceModels(huggingFaceRepo).catch((error) => {
|
||||
if (error) {
|
||||
setErrorMessage(error.data.detail);
|
||||
}
|
||||
});
|
||||
}, [_getHuggingFaceModels, huggingFaceRepo]);
|
||||
|
||||
const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
setHuggingFaceRepo(e.target.value);
|
||||
setErrorMessage('');
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="100%">
|
||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
||||
<Flex direction="column" w="full">
|
||||
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
|
||||
<Input value={huggingFaceRepo} onChange={handleSetHuggingFaceRepo} />
|
||||
</Flex>
|
||||
|
||||
<Button onClick={scanFolder} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
{data && <HuggingFaceResults results={data} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,59 @@
|
||||
import { Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import { useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { IoAdd } from 'react-icons/io5';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
type Props = {
|
||||
result: string;
|
||||
};
|
||||
export const HuggingFaceResultItem = ({ result }: Props) => {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const handleQuickAdd = useCallback(() => {
|
||||
installModel({ source: result })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddedSimple'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}, [installModel, result, dispatch, t]);
|
||||
|
||||
return (
|
||||
<Flex justifyContent="space-between" w="100%">
|
||||
<Flex fontSize="sm" flexDir="column">
|
||||
<Text fontWeight="semibold">{result.split('/').slice(-1)[0]}</Text>
|
||||
<Text variant="subtext">{result}</Text>
|
||||
</Flex>
|
||||
<Box>
|
||||
<Tooltip label={t('modelManager.quickAdd')}>
|
||||
<IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
|
||||
</Tooltip>
|
||||
</Box>
|
||||
</Flex>
|
||||
);
|
||||
};
|
@ -0,0 +1,119 @@
|
||||
import {
|
||||
Button,
|
||||
Divider,
|
||||
Flex,
|
||||
Heading,
|
||||
IconButton,
|
||||
Input,
|
||||
InputGroup,
|
||||
InputRightElement,
|
||||
} from '@invoke-ai/ui-library';
|
||||
import { useAppDispatch } from 'app/store/storeHooks';
|
||||
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
|
||||
import { addToast } from 'features/system/store/systemSlice';
|
||||
import { makeToast } from 'features/system/util/makeToast';
|
||||
import type { ChangeEventHandler } from 'react';
|
||||
import { useCallback, useState } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { PiXBold } from 'react-icons/pi';
|
||||
import { useInstallModelMutation } from 'services/api/endpoints/models';
|
||||
|
||||
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
|
||||
|
||||
type HuggingFaceResultsProps = {
|
||||
// results: HuggingFaceFolderResponse;
|
||||
results: string[];
|
||||
};
|
||||
|
||||
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
const { t } = useTranslation();
|
||||
const [searchTerm, setSearchTerm] = useState('');
|
||||
const dispatch = useAppDispatch();
|
||||
|
||||
const [installModel] = useInstallModelMutation();
|
||||
|
||||
const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
|
||||
setSearchTerm(e.target.value.trim());
|
||||
}, []);
|
||||
|
||||
const clearSearch = useCallback(() => {
|
||||
setSearchTerm('');
|
||||
}, []);
|
||||
|
||||
const handleAddAll = useCallback(() => {
|
||||
for (const result of results) {
|
||||
installModel({ source: result })
|
||||
.unwrap()
|
||||
.then((_) => {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: t('toast.modelAddedSimple'),
|
||||
status: 'success',
|
||||
})
|
||||
)
|
||||
);
|
||||
})
|
||||
.catch((error) => {
|
||||
if (error) {
|
||||
dispatch(
|
||||
addToast(
|
||||
makeToast({
|
||||
title: `${error.data.detail} `,
|
||||
status: 'error',
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
}, [results, installModel, dispatch, t]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<Divider mt={4} />
|
||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Heading fontSize="md" as="h4">
|
||||
{t('modelManager.availableModels')}
|
||||
</Heading>
|
||||
<Flex alignItems="center" gap="4">
|
||||
<Button onClick={handleAddAll} isDisabled={results.length === 0}>
|
||||
{t('modelManager.addAll')}
|
||||
</Button>
|
||||
<InputGroup maxW="300px" size="xs">
|
||||
<Input
|
||||
placeholder={t('modelManager.search')}
|
||||
value={searchTerm}
|
||||
data-testid="board-search-input"
|
||||
onChange={handleSearch}
|
||||
size="xs"
|
||||
/>
|
||||
|
||||
{searchTerm && (
|
||||
<InputRightElement h="full" pe={2}>
|
||||
<IconButton
|
||||
size="sm"
|
||||
variant="link"
|
||||
aria-label={t('boards.clearSearch')}
|
||||
icon={<PiXBold />}
|
||||
onClick={clearSearch}
|
||||
/>
|
||||
</InputRightElement>
|
||||
)}
|
||||
</InputGroup>
|
||||
</Flex>
|
||||
</Flex>
|
||||
<Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
|
||||
<ScrollableContent>
|
||||
<Flex flexDir="column" gap={3}>
|
||||
{results.map((result) => (
|
||||
<HuggingFaceResultItem key={result} result={result} />
|
||||
))}
|
||||
</Flex>
|
||||
</ScrollableContent>
|
||||
</Flex>
|
||||
</Flex>
|
||||
</>
|
||||
);
|
||||
};
|
@ -1,6 +1,7 @@
|
||||
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
|
||||
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
|
||||
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
|
||||
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
|
||||
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
|
||||
@ -16,12 +17,16 @@ export const InstallModels = () => {
|
||||
<Tabs variant="collapse" height="100%">
|
||||
<TabList>
|
||||
<Tab>{t('common.simple')}</Tab>
|
||||
<Tab>{t('modelManager.huggingFace')}</Tab>
|
||||
<Tab>{t('modelManager.scan')}</Tab>
|
||||
</TabList>
|
||||
<TabPanels p={3} height="100%">
|
||||
<TabPanel>
|
||||
<InstallModelForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<HuggingFaceForm />
|
||||
</TabPanel>
|
||||
<TabPanel height="100%">
|
||||
<ScanModelsForm />
|
||||
</TabPanel>
|
||||
|
@ -258,6 +258,13 @@ export const modelsApi = api.injectEndpoints({
|
||||
};
|
||||
},
|
||||
}),
|
||||
getHuggingFaceModels: build.query<string[], string>({
|
||||
query: (hugging_face_repo) => {
|
||||
return {
|
||||
url: buildModelsUrl(`hugging_face?hugging_face_repo=${hugging_face_repo}`),
|
||||
};
|
||||
},
|
||||
}),
|
||||
listModelInstalls: build.query<ListModelInstallsResponse, void>({
|
||||
query: () => {
|
||||
return {
|
||||
@ -381,6 +388,7 @@ export const {
|
||||
useConvertModelMutation,
|
||||
useSyncModelsMutation,
|
||||
useLazyScanFolderQuery,
|
||||
useLazyGetHuggingFaceModelsQuery,
|
||||
useListModelInstallsQuery,
|
||||
useCancelModelInstallMutation,
|
||||
usePruneCompletedModelInstallsMutation,
|
||||
|
Loading…
Reference in New Issue
Block a user