ui consistency, moved is_diffusers logic to backend, extended HuggingFaceMetadata, removed logic from service

This commit is contained in:
Jennifer Player 2024-03-12 21:00:14 -04:00 committed by psychedelicious
parent 2a300ecada
commit d0800c4888
10 changed files with 116 additions and 91 deletions

View File

@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
ModelType, ModelType,
SubModelType, SubModelType,
) )
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -246,6 +248,11 @@ async def scan_for_models(
return scan_results return scan_results
class HuggingFaceModels(BaseModel):
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
@model_manager_router.get( @model_manager_router.get(
"/hugging_face", "/hugging_face",
operation_id="get_hugging_face_models", operation_id="get_hugging_face_models",
@ -254,24 +261,25 @@ async def scan_for_models(
400: {"description": "Invalid hugging face repo"}, 400: {"description": "Invalid hugging face repo"},
}, },
status_code=200, status_code=200,
response_model=List[AnyHttpUrl], response_model=HuggingFaceModels,
) )
async def get_hugging_face_models( async def get_hugging_face_models(
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None), hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> List[AnyHttpUrl]: ) -> HuggingFaceModels:
get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models
try: try:
result = get_hugging_face_models( metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
source=hugging_face_repo, except UnknownMetadataException:
)
except ValueError as e:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"{e}", detail="No HuggingFace repository found",
) )
return result assert isinstance(metadata, ModelMetadataWithFiles)
return HuggingFaceModels(
urls=metadata.ckpt_urls,
is_diffusers=metadata.is_diffusers,
)
@model_manager_router.patch( @model_manager_router.patch(

View File

@ -405,19 +405,6 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod
def get_hugging_face_models(
self,
source: str,
) -> List[AnyHttpUrl]:
"""Get the available models in a HuggingFace repo.
:param source: HuggingFace repo string
This will get the urls for the available models in the indicated,
repo, and return them as a list of AnyHttpUrl strings.
"""
@abstractmethod @abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source.""" """Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -38,7 +38,7 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles, ModelMetadataWithFiles,
RemoteModelFile, RemoteModelFile,
) )
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata, UnknownMetadataException from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger from invokeai.backend.util import Chdir, InvokeAILogger
@ -233,29 +233,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job) self._install_jobs.append(install_job)
return install_job return install_job
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
# Add user's cached access token to HuggingFace requests
access_token = HfFolder.get_token()
if not access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
try:
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
except UnknownMetadataException:
raise ValueError("No HuggingFace repository found")
assert isinstance(metadata, ModelMetadataWithFiles)
urls: List[AnyHttpUrl] = []
for file in metadata.files:
if str(file.url).endswith(
(".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json")
):
urls.append(file.url)
self._logger.info(f"here are the metadata files {metadata.files}")
return urls
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs return self._install_jobs

View File

@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
) )
) )
# diffusers models have a `model_index.json` file
is_diffusers = any(str(f.url).endswith("model_index.json") for f in files)
# These URLs will be exposed to the user - I think these are the only file types we fully support
ckpt_urls = (
None
if is_diffusers
else [
f.url
for f in files
if str(f.url).endswith(
(
".safetensors",
".bin",
".pth",
".pt",
".ckpt",
)
)
]
)
return HuggingFaceMetadata( return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str) id=model_info.id,
name=name,
files=files,
api_response=json.dumps(model_info.__dict__, default=str),
is_diffusers=is_diffusers,
ckpt_urls=ckpt_urls,
) )
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata: def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
type: Literal["huggingface"] = "huggingface" type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id") id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None) api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
description="URLs for all checkpoint format models in the metadata", default=None
)
def download_urls( def download_urls(
self, self,

View File

@ -53,11 +53,11 @@ export const HuggingFaceForm = () => {
_getHuggingFaceModels(huggingFaceRepo) _getHuggingFaceModels(huggingFaceRepo)
.unwrap() .unwrap()
.then((response) => { .then((response) => {
if (response.some((result: string) => result.endsWith('model_index.json'))) { if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo); handleInstallModel(huggingFaceRepo);
setDisplayResults(false); setDisplayResults(false);
} else if (response.length === 1 && response[0]) { } else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response[0]); handleInstallModel(response.urls[0]);
setDisplayResults(false); setDisplayResults(false);
} else { } else {
setDisplayResults(true); setDisplayResults(true);
@ -75,26 +75,27 @@ export const HuggingFaceForm = () => {
return ( return (
<Flex flexDir="column" height="100%"> <Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full"> <FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel> <FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input <Input
placeholder={t('modelManager.huggingFacePlaceholder')} placeholder={t('modelManager.huggingFacePlaceholder')}
value={huggingFaceRepo} value={huggingFaceRepo}
onChange={handleSetHuggingFaceRepo} onChange={handleSetHuggingFaceRepo}
/> />
</Flex> <Button
onClick={getModels}
<Button onClick={getModels} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}> isLoading={isLoading}
isDisabled={huggingFaceRepo.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.addModel')} {t('modelManager.addModel')}
</Button> </Button>
</Flex> </Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>} {!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl> </FormControl>
{data && displayResults && <HuggingFaceResults results={data} />} {data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex> </Flex>
); );
}; };

View File

@ -77,7 +77,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
return ( return (
<> <>
<Divider mt={4} /> <Divider mt={6} />
<Flex flexDir="column" gap={2} mt={4} height="100%"> <Flex flexDir="column" gap={2} mt={4} height="100%">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4"> <Heading fontSize="md" as="h4">

View File

@ -36,24 +36,21 @@ export const ScanModelsForm = () => {
return ( return (
<Flex flexDir="column" height="100%"> <Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full"> <FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormLabel>{t('common.folder')}</FormLabel> <FormLabel>{t('common.folder')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} /> <Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button <Button
onClick={scanFolder} onClick={scanFolder}
isLoading={isLoading} isLoading={isLoading}
isDisabled={scanPath === undefined || scanPath.length === 0} isDisabled={scanPath === undefined || scanPath.length === 0}
size="sm"
flexShrink={0}
> >
{t('modelManager.scanFolder')} {t('modelManager.scanFolder')}
</Button> </Button>
</Flex> </Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>} {!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl> </FormControl>
{data && <ScanModelsResults results={data} />} {data && <ScanModelsResults results={data} />}
</Flex> </Flex>

View File

@ -80,7 +80,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
return ( return (
<> <>
<Divider mt={4} /> <Divider mt={6} />
<Flex flexDir="column" gap={2} mt={4} height="100%"> <Flex flexDir="column" gap={2} mt={4} height="100%">
<Flex justifyContent="space-between" alignItems="center"> <Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4"> <Heading fontSize="md" as="h4">

File diff suppressed because one or more lines are too long