ui consistency, moved is_diffusers logic to backend, extended HuggingFaceMetadata, removed logic from service

This commit is contained in:
Jennifer Player 2024-03-12 21:00:14 -04:00 committed by psychedelicious
parent 2a300ecada
commit d0800c4888
10 changed files with 116 additions and 91 deletions

View File

@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
from invokeai.backend.model_manager.search import ModelSearch
from ..dependencies import ApiDependencies
@ -246,6 +248,11 @@ async def scan_for_models(
return scan_results
class HuggingFaceModels(BaseModel):
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
@model_manager_router.get(
"/hugging_face",
operation_id="get_hugging_face_models",
@ -254,24 +261,25 @@ async def scan_for_models(
400: {"description": "Invalid hugging face repo"},
},
status_code=200,
response_model=List[AnyHttpUrl],
response_model=HuggingFaceModels,
)
async def get_hugging_face_models(
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> List[AnyHttpUrl]:
get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models
) -> HuggingFaceModels:
try:
result = get_hugging_face_models(
source=hugging_face_repo,
)
except ValueError as e:
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
except UnknownMetadataException:
raise HTTPException(
status_code=400,
detail=f"{e}",
detail="No HuggingFace repository found",
)
return result
assert isinstance(metadata, ModelMetadataWithFiles)
return HuggingFaceModels(
urls=metadata.ckpt_urls,
is_diffusers=metadata.is_diffusers,
)
@model_manager_router.patch(

View File

@ -405,19 +405,6 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def get_hugging_face_models(
self,
source: str,
) -> List[AnyHttpUrl]:
"""Get the available models in a HuggingFace repo.
:param source: HuggingFace repo string
This will get the urls for the available models in the indicated,
repo, and return them as a list of AnyHttpUrl strings.
"""
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -38,7 +38,7 @@ from invokeai.backend.model_manager.metadata import (
ModelMetadataWithFiles,
RemoteModelFile,
)
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata, UnknownMetadataException
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
from invokeai.backend.model_manager.probe import ModelProbe
from invokeai.backend.model_manager.search import ModelSearch
from invokeai.backend.util import Chdir, InvokeAILogger
@ -233,29 +233,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job)
return install_job
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
# Add user's cached access token to HuggingFace requests
access_token = HfFolder.get_token()
if not access_token:
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
try:
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
except UnknownMetadataException:
raise ValueError("No HuggingFace repository found")
assert isinstance(metadata, ModelMetadataWithFiles)
urls: List[AnyHttpUrl] = []
for file in metadata.files:
if str(file.url).endswith(
(".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json")
):
urls.append(file.url)
self._logger.info(f"here are the metadata files {metadata.files}")
return urls
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs

View File

@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
)
)
# diffusers models have a `model_index.json` file
is_diffusers = any(str(f.url).endswith("model_index.json") for f in files)
# These URLs will be exposed to the user - I think these are the only file types we fully support
ckpt_urls = (
None
if is_diffusers
else [
f.url
for f in files
if str(f.url).endswith(
(
".safetensors",
".bin",
".pth",
".pt",
".ckpt",
)
)
]
)
return HuggingFaceMetadata(
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
id=model_info.id,
name=name,
files=files,
api_response=json.dumps(model_info.__dict__, default=str),
is_diffusers=is_diffusers,
ckpt_urls=ckpt_urls,
)
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:

View File

@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
type: Literal["huggingface"] = "huggingface"
id: str = Field(description="The HF model id")
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
description="URLs for all checkpoint format models in the metadata", default=None
)
def download_urls(
self,

View File

@ -53,11 +53,11 @@ export const HuggingFaceForm = () => {
_getHuggingFaceModels(huggingFaceRepo)
.unwrap()
.then((response) => {
if (response.some((result: string) => result.endsWith('model_index.json'))) {
if (response.is_diffusers) {
handleInstallModel(huggingFaceRepo);
setDisplayResults(false);
} else if (response.length === 1 && response[0]) {
handleInstallModel(response[0]);
} else if (response.urls?.length === 1 && response.urls[0]) {
handleInstallModel(response.urls[0]);
setDisplayResults(false);
} else {
setDisplayResults(true);
@ -75,26 +75,27 @@ export const HuggingFaceForm = () => {
return (
<Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input
placeholder={t('modelManager.huggingFacePlaceholder')}
value={huggingFaceRepo}
onChange={handleSetHuggingFaceRepo}
/>
</Flex>
<Button onClick={getModels} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
<Button
onClick={getModels}
isLoading={isLoading}
isDisabled={huggingFaceRepo.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.addModel')}
</Button>
</Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl>
{data && displayResults && <HuggingFaceResults results={data} />}
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
</Flex>
);
};

View File

@ -77,7 +77,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
return (
<>
<Divider mt={4} />
<Divider mt={6} />
<Flex flexDir="column" gap={2} mt={4} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4">

View File

@ -36,24 +36,21 @@ export const ScanModelsForm = () => {
return (
<Flex flexDir="column" height="100%">
<FormControl isInvalid={!!errorMessage.length} w="full">
<Flex flexDir="column" w="full">
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
<Flex direction="column" w="full">
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
<FormLabel>{t('common.folder')}</FormLabel>
<Flex gap={3} alignItems="center" w="full">
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
</Flex>
<Button
onClick={scanFolder}
isLoading={isLoading}
isDisabled={scanPath === undefined || scanPath.length === 0}
size="sm"
flexShrink={0}
>
{t('modelManager.scanFolder')}
</Button>
</Flex>
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
</Flex>
</FormControl>
{data && <ScanModelsResults results={data} />}
</Flex>

View File

@ -80,7 +80,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
return (
<>
<Divider mt={4} />
<Divider mt={6} />
<Flex flexDir="column" gap={2} mt={4} height="100%">
<Flex justifyContent="space-between" alignItems="center">
<Heading fontSize="md" as="h4">

File diff suppressed because one or more lines are too long