mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ui consistency, moved is_diffusers logic to backend, extended HuggingFaceMetadata, removed logic from service
This commit is contained in:
parent
2a300ecada
commit
d0800c4888
@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
|
||||
from ..dependencies import ApiDependencies
|
||||
@ -246,6 +248,11 @@ async def scan_for_models(
|
||||
return scan_results
|
||||
|
||||
|
||||
class HuggingFaceModels(BaseModel):
|
||||
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
|
||||
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/hugging_face",
|
||||
operation_id="get_hugging_face_models",
|
||||
@ -254,24 +261,25 @@ async def scan_for_models(
|
||||
400: {"description": "Invalid hugging face repo"},
|
||||
},
|
||||
status_code=200,
|
||||
response_model=List[AnyHttpUrl],
|
||||
response_model=HuggingFaceModels,
|
||||
)
|
||||
async def get_hugging_face_models(
|
||||
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
||||
) -> List[AnyHttpUrl]:
|
||||
get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models
|
||||
|
||||
) -> HuggingFaceModels:
|
||||
try:
|
||||
result = get_hugging_face_models(
|
||||
source=hugging_face_repo,
|
||||
)
|
||||
except ValueError as e:
|
||||
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
|
||||
except UnknownMetadataException:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"{e}",
|
||||
detail="No HuggingFace repository found",
|
||||
)
|
||||
|
||||
return result
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
return HuggingFaceModels(
|
||||
urls=metadata.ckpt_urls,
|
||||
is_diffusers=metadata.is_diffusers,
|
||||
)
|
||||
|
||||
|
||||
@model_manager_router.patch(
|
||||
|
@ -405,19 +405,6 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_hugging_face_models(
|
||||
self,
|
||||
source: str,
|
||||
) -> List[AnyHttpUrl]:
|
||||
"""Get the available models in a HuggingFace repo.
|
||||
|
||||
:param source: HuggingFace repo string
|
||||
|
||||
This will get the urls for the available models in the indicated,
|
||||
repo, and return them as a list of AnyHttpUrl strings.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
@ -38,7 +38,7 @@ from invokeai.backend.model_manager.metadata import (
|
||||
ModelMetadataWithFiles,
|
||||
RemoteModelFile,
|
||||
)
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata, UnknownMetadataException
|
||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||
from invokeai.backend.model_manager.probe import ModelProbe
|
||||
from invokeai.backend.model_manager.search import ModelSearch
|
||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||
@ -233,29 +233,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
access_token = HfFolder.get_token()
|
||||
if not access_token:
|
||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
||||
|
||||
try:
|
||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
|
||||
except UnknownMetadataException:
|
||||
raise ValueError("No HuggingFace repository found")
|
||||
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
urls: List[AnyHttpUrl] = []
|
||||
|
||||
for file in metadata.files:
|
||||
if str(file.url).endswith(
|
||||
(".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json")
|
||||
):
|
||||
urls.append(file.url)
|
||||
|
||||
self._logger.info(f"here are the metadata files {metadata.files}")
|
||||
return urls
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
|
@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
||||
)
|
||||
)
|
||||
|
||||
# diffusers models have a `model_index.json` file
|
||||
is_diffusers = any(str(f.url).endswith("model_index.json") for f in files)
|
||||
|
||||
# These URLs will be exposed to the user - I think these are the only file types we fully support
|
||||
ckpt_urls = (
|
||||
None
|
||||
if is_diffusers
|
||||
else [
|
||||
f.url
|
||||
for f in files
|
||||
if str(f.url).endswith(
|
||||
(
|
||||
".safetensors",
|
||||
".bin",
|
||||
".pth",
|
||||
".pt",
|
||||
".ckpt",
|
||||
)
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return HuggingFaceMetadata(
|
||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
||||
id=model_info.id,
|
||||
name=name,
|
||||
files=files,
|
||||
api_response=json.dumps(model_info.__dict__, default=str),
|
||||
is_diffusers=is_diffusers,
|
||||
ckpt_urls=ckpt_urls,
|
||||
)
|
||||
|
||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||
|
@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
||||
type: Literal["huggingface"] = "huggingface"
|
||||
id: str = Field(description="The HF model id")
|
||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
|
||||
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
|
||||
description="URLs for all checkpoint format models in the metadata", default=None
|
||||
)
|
||||
|
||||
def download_urls(
|
||||
self,
|
||||
|
@ -53,11 +53,11 @@ export const HuggingFaceForm = () => {
|
||||
_getHuggingFaceModels(huggingFaceRepo)
|
||||
.unwrap()
|
||||
.then((response) => {
|
||||
if (response.some((result: string) => result.endsWith('model_index.json'))) {
|
||||
if (response.is_diffusers) {
|
||||
handleInstallModel(huggingFaceRepo);
|
||||
setDisplayResults(false);
|
||||
} else if (response.length === 1 && response[0]) {
|
||||
handleInstallModel(response[0]);
|
||||
} else if (response.urls?.length === 1 && response.urls[0]) {
|
||||
handleInstallModel(response.urls[0]);
|
||||
setDisplayResults(false);
|
||||
} else {
|
||||
setDisplayResults(true);
|
||||
@ -75,26 +75,27 @@ export const HuggingFaceForm = () => {
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="100%">
|
||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
||||
<Flex direction="column" w="full">
|
||||
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
|
||||
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Input
|
||||
placeholder={t('modelManager.huggingFacePlaceholder')}
|
||||
value={huggingFaceRepo}
|
||||
onChange={handleSetHuggingFaceRepo}
|
||||
/>
|
||||
</Flex>
|
||||
|
||||
<Button onClick={getModels} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
|
||||
<Button
|
||||
onClick={getModels}
|
||||
isLoading={isLoading}
|
||||
isDisabled={huggingFaceRepo.length === 0}
|
||||
size="sm"
|
||||
flexShrink={0}
|
||||
>
|
||||
{t('modelManager.addModel')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
{data && displayResults && <HuggingFaceResults results={data} />}
|
||||
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||
</Flex>
|
||||
);
|
||||
};
|
||||
|
@ -77,7 +77,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Divider mt={4} />
|
||||
<Divider mt={6} />
|
||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Heading fontSize="md" as="h4">
|
||||
|
@ -36,24 +36,21 @@ export const ScanModelsForm = () => {
|
||||
|
||||
return (
|
||||
<Flex flexDir="column" height="100%">
|
||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
||||
<Flex flexDir="column" w="full">
|
||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
||||
<Flex direction="column" w="full">
|
||||
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
|
||||
<FormLabel>{t('common.folder')}</FormLabel>
|
||||
<Flex gap={3} alignItems="center" w="full">
|
||||
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
|
||||
</Flex>
|
||||
|
||||
<Button
|
||||
onClick={scanFolder}
|
||||
isLoading={isLoading}
|
||||
isDisabled={scanPath === undefined || scanPath.length === 0}
|
||||
size="sm"
|
||||
flexShrink={0}
|
||||
>
|
||||
{t('modelManager.scanFolder')}
|
||||
</Button>
|
||||
</Flex>
|
||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||
</Flex>
|
||||
</FormControl>
|
||||
{data && <ScanModelsResults results={data} />}
|
||||
</Flex>
|
||||
|
@ -80,7 +80,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<Divider mt={4} />
|
||||
<Divider mt={6} />
|
||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
||||
<Flex justifyContent="space-between" alignItems="center">
|
||||
<Heading fontSize="md" as="h4">
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user