Added a Hugging Face models import tab and a route for fetching the available models in a Hugging Face repo

This commit is contained in:
Jennifer Player 2024-03-07 16:57:28 -05:00 committed by psychedelicious
parent efea1a8a7d
commit f7cd3cf1f4
10 changed files with 308 additions and 3 deletions

View File

@ -11,7 +11,7 @@ from fastapi import Body, Path, Query, Response, UploadFile
from fastapi.responses import FileResponse
from fastapi.routing import APIRouter
from PIL import Image
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
from starlette.exceptions import HTTPException
from typing_extensions import Annotated
@ -29,6 +29,7 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -246,6 +247,29 @@ async def scan_for_models(
return scan_results
@model_manager_router.get(
    "/hugging_face",
    operation_id="get_hugging_face_models",
    responses={
        200: {"description": "Hugging Face repo scanned successfully"},
        400: {"description": "Invalid hugging face repo"},
    },
    status_code=200,
    response_model=List[AnyHttpUrl],
)
async def get_hugging_face_models(
    hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> List[AnyHttpUrl]:
    """Scan a HuggingFace repo and return the download URLs of the model files it contains.

    :param hugging_face_repo: HuggingFace repo id string (e.g. "owner/repo").
    :returns: list of AnyHttpUrl download URLs, one per model file in the repo.
    """
    # The original body called the installer service twice (a stray bare call
    # before the keyword call), performing the remote metadata fetch twice per
    # request. Call it exactly once, and avoid shadowing this route's name.
    installer = ApiDependencies.invoker.services.model_manager.install
    return installer.get_hugging_face_models(source=hugging_face_repo)
@model_manager_router.patch(
"/i/{key}",
operation_id="update_model_record",

View File

@ -20,6 +20,7 @@ from invokeai.app.services.model_records import ModelRecordServiceBase
from invokeai.backend.model_manager import AnyModelConfig, ModelRepoVariant
from invokeai.backend.model_manager.config import ModelSourceType
from invokeai.backend.model_manager.metadata import AnyModelRepoMetadata
from invokeai.backend.model_manager.metadata.metadata_base import RemoteModelFile
class InstallStatus(str, Enum):
@ -405,6 +406,19 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def get_hugging_face_models(
    self,
    source: str,
) -> List[AnyHttpUrl]:
    """Get the available models in a HuggingFace repo.

    :param source: HuggingFace repo id string (e.g. "owner/repo").
    :returns: the URLs of the model files available in the indicated repo,
        as a list of AnyHttpUrl values.
    """
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -233,6 +233,22 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job)
return install_job
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
    """Return the download URLs of the model files in the indicated HuggingFace repo.

    :param source: HuggingFace repo id string (e.g. "owner/repo").
    :returns: one AnyHttpUrl per remote model file.
    """
    # Warn when no cached HuggingFace access token is present; gated or
    # private repos will be missing from the results without one.
    if not HfFolder.get_token():
        self._logger.info("No HuggingFace access token present; some models may not be downloadable.")

    metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
    # Removed a leftover debug log of the full metadata object here.
    assert isinstance(metadata, ModelMetadataWithFiles)
    remote_files = metadata.download_urls(session=self._session)
    return [file.url for file in remote_files]
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs

View File

@ -51,7 +51,15 @@ def filter_files(
(
"learned_embeds.bin",
"ip_adapter.bin",
"lora_weights.safetensors",
# NOTE(review): the allowed weight-file extensions were broadened here to cover
# more model formats — confirm the wider match does not pull in unwanted files.
".safetensors",
".bin",
".onnx",
".xml",
".pth",
".pt",
".ckpt",
".msgpack",
"weights.pb",
"onnx_data",
)
@ -71,7 +79,8 @@ def filter_files(
paths = [x for x in paths if x.parent == Path(subfolder)]
# _filter_by_variant uniquifies the paths and returns a set
return sorted(_filter_by_variant(paths, variant))
# NOTE(review): the variant filter was removed because it discarded valid models;
# investigate why _filter_by_variant drops them before restoring the filter.
return sorted(paths)
@dataclass

View File

@ -762,6 +762,8 @@
"formMessageDiffusersVAELocationDesc": "If not provided, InvokeAI will look for the VAE file inside the model location given above.",
"height": "Height",
"heightValidationMsg": "Default height of your model.",
"huggingFace": "Hugging Face",
"huggingFaceRepoID": "HuggingFace Repo ID",
"ignoreMismatch": "Ignore Mismatches Between Selected Models",
"imageEncoderModelId": "Image Encoder Model ID",
"importModels": "Import Models",

View File

@ -0,0 +1,49 @@
import { Button, Flex, FormControl, FormErrorMessage, FormLabel, Input } from '@invoke-ai/ui-library';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { useLazyGetHuggingFaceModelsQuery } from 'services/api/endpoints/models';
import { HuggingFaceResults } from './HuggingFaceResults';
/**
 * Form for entering a HuggingFace repo id and fetching the model files it
 * contains. Results are rendered below the form via HuggingFaceResults.
 */
export const HuggingFaceForm = () => {
  const [huggingFaceRepo, setHuggingFaceRepo] = useState('');
  const [errorMessage, setErrorMessage] = useState('');
  const { t } = useTranslation();

  const [_getHuggingFaceModels, { isLoading, data }] = useLazyGetHuggingFaceModelsQuery();

  // Fetch the model-file URLs for the entered repo id.
  // (Renamed from `scanFolder`, which described a different feature.)
  const getModels = useCallback(async () => {
    _getHuggingFaceModels(huggingFaceRepo).catch((error) => {
      if (error) {
        // The server reports an invalid repo via a 400 whose body carries
        // `detail`; guard the access so an unexpected error shape does not
        // throw inside the catch handler.
        setErrorMessage(error.data?.detail ?? '');
      }
    });
  }, [_getHuggingFaceModels, huggingFaceRepo]);

  // Typing in the repo field clears any stale error from a previous attempt.
  const handleSetHuggingFaceRepo: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
    setHuggingFaceRepo(e.target.value);
    setErrorMessage('');
  }, []);

  return (
    <Flex flexDir="column" height="100%">
      <FormControl isInvalid={!!errorMessage.length} w="full">
        <Flex flexDir="column" w="full">
          <Flex gap={2} alignItems="flex-end" justifyContent="space-between">
            <Flex direction="column" w="full">
              <FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
              <Input value={huggingFaceRepo} onChange={handleSetHuggingFaceRepo} />
            </Flex>

            <Button onClick={getModels} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
              {t('modelManager.addModel')}
            </Button>
          </Flex>
          {!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
        </Flex>
      </FormControl>
      {data && <HuggingFaceResults results={data} />}
    </Flex>
  );
};

View File

@ -0,0 +1,59 @@
import { Box, Flex, IconButton, Text, Tooltip } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { IoAdd } from 'react-icons/io5';
import { useInstallModelMutation } from 'services/api/endpoints/models';
type Props = {
  // Full download URL of a single model file in the scanned repo.
  result: string;
};

/** One row in the HuggingFace scan results, with a quick-add install button. */
export const HuggingFaceResultItem = ({ result }: Props) => {
  const { t } = useTranslation();
  const dispatch = useAppDispatch();
  const [installModel] = useInstallModelMutation();

  // Kick off an install for this URL and surface the outcome as a toast.
  const handleQuickAdd = useCallback(() => {
    const onSuccess = () => {
      dispatch(addToast(makeToast({ title: t('toast.modelAddedSimple'), status: 'success' })));
    };
    const onError = (error: any) => {
      if (error) {
        dispatch(addToast(makeToast({ title: `${error.data.detail} `, status: 'error' })));
      }
    };
    installModel({ source: result }).unwrap().then(onSuccess).catch(onError);
  }, [installModel, result, dispatch, t]);

  // Display the last path segment as the model's short name.
  const shortName = result.split('/').slice(-1)[0];

  return (
    <Flex justifyContent="space-between" w="100%">
      <Flex fontSize="sm" flexDir="column">
        <Text fontWeight="semibold">{shortName}</Text>
        <Text variant="subtext">{result}</Text>
      </Flex>
      <Box>
        <Tooltip label={t('modelManager.quickAdd')}>
          <IconButton aria-label={t('modelManager.quickAdd')} icon={<IoAdd />} onClick={handleQuickAdd} />
        </Tooltip>
      </Box>
    </Flex>
  );
};

View File

@ -0,0 +1,119 @@
import {
Button,
Divider,
Flex,
Heading,
IconButton,
Input,
InputGroup,
InputRightElement,
} from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { addToast } from 'features/system/store/systemSlice';
import { makeToast } from 'features/system/util/makeToast';
import type { ChangeEventHandler } from 'react';
import { useCallback, useState } from 'react';
import { useTranslation } from 'react-i18next';
import { PiXBold } from 'react-icons/pi';
import { useInstallModelMutation } from 'services/api/endpoints/models';
import { HuggingFaceResultItem } from './HuggingFaceResultItem';
type HuggingFaceResultsProps = {
  // Download URLs of the model files found in the scanned repo.
  results: string[];
};

/**
 * Lists the model files found in a HuggingFace repo, with a search box,
 * per-item quick-add, and an "add all" action.
 */
export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
  const { t } = useTranslation();
  const [searchTerm, setSearchTerm] = useState('');
  const dispatch = useAppDispatch();
  const [installModel] = useInstallModelMutation();

  // Case-insensitive filter driven by the search box. The original rendered
  // every result regardless of the entered term, leaving the box inert.
  const lowered = searchTerm.toLowerCase();
  const filteredResults = results.filter((result) => result.toLowerCase().includes(lowered));

  const handleSearch: ChangeEventHandler<HTMLInputElement> = useCallback((e) => {
    setSearchTerm(e.target.value.trim());
  }, []);

  const clearSearch = useCallback(() => {
    setSearchTerm('');
  }, []);

  // Queue an install for every result (intentionally ignores the search
  // filter, matching the button's "add all" label) and toast each outcome.
  const handleAddAll = useCallback(() => {
    for (const result of results) {
      installModel({ source: result })
        .unwrap()
        .then((_) => {
          dispatch(
            addToast(
              makeToast({
                title: t('toast.modelAddedSimple'),
                status: 'success',
              })
            )
          );
        })
        .catch((error) => {
          if (error) {
            dispatch(
              addToast(
                makeToast({
                  title: `${error.data.detail} `,
                  status: 'error',
                })
              )
            );
          }
        });
    }
  }, [results, installModel, dispatch, t]);

  return (
    <>
      <Divider mt={4} />
      <Flex flexDir="column" gap={2} mt={4} height="100%">
        <Flex justifyContent="space-between" alignItems="center">
          <Heading fontSize="md" as="h4">
            {t('modelManager.availableModels')}
          </Heading>
          <Flex alignItems="center" gap="4">
            <Button onClick={handleAddAll} isDisabled={results.length === 0}>
              {t('modelManager.addAll')}
            </Button>
            <InputGroup maxW="300px" size="xs">
              <Input
                placeholder={t('modelManager.search')}
                value={searchTerm}
                data-testid="board-search-input"
                onChange={handleSearch}
                size="xs"
              />

              {searchTerm && (
                <InputRightElement h="full" pe={2}>
                  <IconButton
                    size="sm"
                    variant="link"
                    aria-label={t('boards.clearSearch')}
                    icon={<PiXBold />}
                    onClick={clearSearch}
                  />
                </InputRightElement>
              )}
            </InputGroup>
          </Flex>
        </Flex>
        <Flex height="100%" layerStyle="third" borderRadius="base" p={4} mt={4} mb={4}>
          <ScrollableContent>
            <Flex flexDir="column" gap={3}>
              {filteredResults.map((result) => (
                <HuggingFaceResultItem key={result} result={result} />
              ))}
            </Flex>
          </ScrollableContent>
        </Flex>
      </Flex>
    </>
  );
};

View File

@ -1,6 +1,7 @@
import { Box, Flex, Heading, Tab, TabList, TabPanel, TabPanels, Tabs } from '@invoke-ai/ui-library';
import { useTranslation } from 'react-i18next';
import { HuggingFaceForm } from './AddModelPanel/HuggingFaceFolder/HuggingFaceForm';
import { InstallModelForm } from './AddModelPanel/InstallModelForm';
import { ModelInstallQueue } from './AddModelPanel/ModelInstallQueue/ModelInstallQueue';
import { ScanModelsForm } from './AddModelPanel/ScanFolder/ScanFolderForm';
@ -16,12 +17,16 @@ export const InstallModels = () => {
<Tabs variant="collapse" height="100%">
<TabList>
<Tab>{t('common.simple')}</Tab>
<Tab>{t('modelManager.huggingFace')}</Tab>
<Tab>{t('modelManager.scan')}</Tab>
</TabList>
<TabPanels p={3} height="100%">
<TabPanel>
<InstallModelForm />
</TabPanel>
<TabPanel height="100%">
<HuggingFaceForm />
</TabPanel>
<TabPanel height="100%">
<ScanModelsForm />
</TabPanel>

View File

@ -258,6 +258,13 @@ export const modelsApi = api.injectEndpoints({
};
},
}),
getHuggingFaceModels: build.query<string[], string>({
  query: (hugging_face_repo) => {
    return {
      // Repo ids contain '/' and may contain other reserved characters;
      // raw interpolation into the query string produced a malformed URL.
      url: buildModelsUrl(`hugging_face?hugging_face_repo=${encodeURIComponent(hugging_face_repo)}`),
    };
  },
}),
listModelInstalls: build.query<ListModelInstallsResponse, void>({
query: () => {
return {
@ -381,6 +388,7 @@ export const {
useConvertModelMutation,
useSyncModelsMutation,
useLazyScanFolderQuery,
useLazyGetHuggingFaceModelsQuery,
useListModelInstallsQuery,
useCancelModelInstallMutation,
usePruneCompletedModelInstallsMutation,