mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ui consistency, moved is_diffusers logic to backend, extended HuggingFaceMetadata, removed logic from service
This commit is contained in:
parent
2a300ecada
commit
d0800c4888
@ -29,6 +29,8 @@ from invokeai.backend.model_manager.config import (
|
|||||||
ModelType,
|
ModelType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
from invokeai.backend.model_manager.metadata.fetch.huggingface import HuggingFaceMetadataFetch
|
||||||
|
from invokeai.backend.model_manager.metadata.metadata_base import ModelMetadataWithFiles, UnknownMetadataException
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
@ -246,6 +248,11 @@ async def scan_for_models(
|
|||||||
return scan_results
|
return scan_results
|
||||||
|
|
||||||
|
|
||||||
|
class HuggingFaceModels(BaseModel):
|
||||||
|
urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
|
||||||
|
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.get(
|
@model_manager_router.get(
|
||||||
"/hugging_face",
|
"/hugging_face",
|
||||||
operation_id="get_hugging_face_models",
|
operation_id="get_hugging_face_models",
|
||||||
@ -254,24 +261,25 @@ async def scan_for_models(
|
|||||||
400: {"description": "Invalid hugging face repo"},
|
400: {"description": "Invalid hugging face repo"},
|
||||||
},
|
},
|
||||||
status_code=200,
|
status_code=200,
|
||||||
response_model=List[AnyHttpUrl],
|
response_model=HuggingFaceModels,
|
||||||
)
|
)
|
||||||
async def get_hugging_face_models(
|
async def get_hugging_face_models(
|
||||||
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
|
||||||
) -> List[AnyHttpUrl]:
|
) -> HuggingFaceModels:
|
||||||
get_hugging_face_models = ApiDependencies.invoker.services.model_manager.install.get_hugging_face_models
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = get_hugging_face_models(
|
metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
|
||||||
source=hugging_face_repo,
|
except UnknownMetadataException:
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=400,
|
status_code=400,
|
||||||
detail=f"{e}",
|
detail="No HuggingFace repository found",
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
|
||||||
|
return HuggingFaceModels(
|
||||||
|
urls=metadata.ckpt_urls,
|
||||||
|
is_diffusers=metadata.is_diffusers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@model_manager_router.patch(
|
@model_manager_router.patch(
|
||||||
|
@ -405,19 +405,6 @@ class ModelInstallServiceBase(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_hugging_face_models(
|
|
||||||
self,
|
|
||||||
source: str,
|
|
||||||
) -> List[AnyHttpUrl]:
|
|
||||||
"""Get the available models in a HuggingFace repo.
|
|
||||||
|
|
||||||
:param source: HuggingFace repo string
|
|
||||||
|
|
||||||
This will get the urls for the available models in the indicated,
|
|
||||||
repo, and return them as a list of AnyHttpUrl strings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||||
|
@ -38,7 +38,7 @@ from invokeai.backend.model_manager.metadata import (
|
|||||||
ModelMetadataWithFiles,
|
ModelMetadataWithFiles,
|
||||||
RemoteModelFile,
|
RemoteModelFile,
|
||||||
)
|
)
|
||||||
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata, UnknownMetadataException
|
from invokeai.backend.model_manager.metadata.metadata_base import HuggingFaceMetadata
|
||||||
from invokeai.backend.model_manager.probe import ModelProbe
|
from invokeai.backend.model_manager.probe import ModelProbe
|
||||||
from invokeai.backend.model_manager.search import ModelSearch
|
from invokeai.backend.model_manager.search import ModelSearch
|
||||||
from invokeai.backend.util import Chdir, InvokeAILogger
|
from invokeai.backend.util import Chdir, InvokeAILogger
|
||||||
@ -233,29 +233,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._install_jobs.append(install_job)
|
self._install_jobs.append(install_job)
|
||||||
return install_job
|
return install_job
|
||||||
|
|
||||||
def get_hugging_face_models(self, source: str) -> List[AnyHttpUrl]:
|
|
||||||
# Add user's cached access token to HuggingFace requests
|
|
||||||
access_token = HfFolder.get_token()
|
|
||||||
if not access_token:
|
|
||||||
self._logger.info("No HuggingFace access token present; some models may not be downloadable.")
|
|
||||||
|
|
||||||
try:
|
|
||||||
metadata = HuggingFaceMetadataFetch(self._session).from_id(source)
|
|
||||||
except UnknownMetadataException:
|
|
||||||
raise ValueError("No HuggingFace repository found")
|
|
||||||
|
|
||||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
|
||||||
urls: List[AnyHttpUrl] = []
|
|
||||||
|
|
||||||
for file in metadata.files:
|
|
||||||
if str(file.url).endswith(
|
|
||||||
(".safetensors", ".bin", ".onnx", ".xml", ".pth", ".pt", ".ckpt", ".msgpack", "model_index.json")
|
|
||||||
):
|
|
||||||
urls.append(file.url)
|
|
||||||
|
|
||||||
self._logger.info(f"here are the metadata files {metadata.files}")
|
|
||||||
return urls
|
|
||||||
|
|
||||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
|
@ -90,8 +90,35 @@ class HuggingFaceMetadataFetch(ModelMetadataFetchBase):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# diffusers models have a `model_index.json` file
|
||||||
|
is_diffusers = any(str(f.url).endswith("model_index.json") for f in files)
|
||||||
|
|
||||||
|
# These URLs will be exposed to the user - I think these are the only file types we fully support
|
||||||
|
ckpt_urls = (
|
||||||
|
None
|
||||||
|
if is_diffusers
|
||||||
|
else [
|
||||||
|
f.url
|
||||||
|
for f in files
|
||||||
|
if str(f.url).endswith(
|
||||||
|
(
|
||||||
|
".safetensors",
|
||||||
|
".bin",
|
||||||
|
".pth",
|
||||||
|
".pt",
|
||||||
|
".ckpt",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
return HuggingFaceMetadata(
|
return HuggingFaceMetadata(
|
||||||
id=model_info.id, name=name, files=files, api_response=json.dumps(model_info.__dict__, default=str)
|
id=model_info.id,
|
||||||
|
name=name,
|
||||||
|
files=files,
|
||||||
|
api_response=json.dumps(model_info.__dict__, default=str),
|
||||||
|
is_diffusers=is_diffusers,
|
||||||
|
ckpt_urls=ckpt_urls,
|
||||||
)
|
)
|
||||||
|
|
||||||
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
def from_url(self, url: AnyHttpUrl) -> AnyModelRepoMetadata:
|
||||||
|
@ -84,6 +84,10 @@ class HuggingFaceMetadata(ModelMetadataWithFiles):
|
|||||||
type: Literal["huggingface"] = "huggingface"
|
type: Literal["huggingface"] = "huggingface"
|
||||||
id: str = Field(description="The HF model id")
|
id: str = Field(description="The HF model id")
|
||||||
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
api_response: Optional[str] = Field(description="Response from the HF API as stringified JSON", default=None)
|
||||||
|
is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model", default=False)
|
||||||
|
ckpt_urls: Optional[List[AnyHttpUrl]] = Field(
|
||||||
|
description="URLs for all checkpoint format models in the metadata", default=None
|
||||||
|
)
|
||||||
|
|
||||||
def download_urls(
|
def download_urls(
|
||||||
self,
|
self,
|
||||||
|
@ -53,11 +53,11 @@ export const HuggingFaceForm = () => {
|
|||||||
_getHuggingFaceModels(huggingFaceRepo)
|
_getHuggingFaceModels(huggingFaceRepo)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.then((response) => {
|
.then((response) => {
|
||||||
if (response.some((result: string) => result.endsWith('model_index.json'))) {
|
if (response.is_diffusers) {
|
||||||
handleInstallModel(huggingFaceRepo);
|
handleInstallModel(huggingFaceRepo);
|
||||||
setDisplayResults(false);
|
setDisplayResults(false);
|
||||||
} else if (response.length === 1 && response[0]) {
|
} else if (response.urls?.length === 1 && response.urls[0]) {
|
||||||
handleInstallModel(response[0]);
|
handleInstallModel(response.urls[0]);
|
||||||
setDisplayResults(false);
|
setDisplayResults(false);
|
||||||
} else {
|
} else {
|
||||||
setDisplayResults(true);
|
setDisplayResults(true);
|
||||||
@ -75,26 +75,27 @@ export const HuggingFaceForm = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" height="100%">
|
<Flex flexDir="column" height="100%">
|
||||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
|
||||||
<Flex flexDir="column" w="full">
|
|
||||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
|
||||||
<Flex direction="column" w="full">
|
|
||||||
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
|
<FormLabel>{t('modelManager.huggingFaceRepoID')}</FormLabel>
|
||||||
|
<Flex gap={3} alignItems="center" w="full">
|
||||||
<Input
|
<Input
|
||||||
placeholder={t('modelManager.huggingFacePlaceholder')}
|
placeholder={t('modelManager.huggingFacePlaceholder')}
|
||||||
value={huggingFaceRepo}
|
value={huggingFaceRepo}
|
||||||
onChange={handleSetHuggingFaceRepo}
|
onChange={handleSetHuggingFaceRepo}
|
||||||
/>
|
/>
|
||||||
</Flex>
|
<Button
|
||||||
|
onClick={getModels}
|
||||||
<Button onClick={getModels} isLoading={isLoading} isDisabled={huggingFaceRepo.length === 0}>
|
isLoading={isLoading}
|
||||||
|
isDisabled={huggingFaceRepo.length === 0}
|
||||||
|
size="sm"
|
||||||
|
flexShrink={0}
|
||||||
|
>
|
||||||
{t('modelManager.addModel')}
|
{t('modelManager.addModel')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{data && displayResults && <HuggingFaceResults results={data} />}
|
{data && data.urls && displayResults && <HuggingFaceResults results={data.urls} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
|
@ -77,7 +77,7 @@ export const HuggingFaceResults = ({ results }: HuggingFaceResultsProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Divider mt={4} />
|
<Divider mt={6} />
|
||||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
<Heading fontSize="md" as="h4">
|
<Heading fontSize="md" as="h4">
|
||||||
|
@ -36,24 +36,21 @@ export const ScanModelsForm = () => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<Flex flexDir="column" height="100%">
|
<Flex flexDir="column" height="100%">
|
||||||
<FormControl isInvalid={!!errorMessage.length} w="full">
|
<FormControl isInvalid={!!errorMessage.length} w="full" orientation="vertical">
|
||||||
<Flex flexDir="column" w="full">
|
|
||||||
<Flex gap={2} alignItems="flex-end" justifyContent="space-between">
|
|
||||||
<Flex direction="column" w="full">
|
|
||||||
<FormLabel>{t('common.folder')}</FormLabel>
|
<FormLabel>{t('common.folder')}</FormLabel>
|
||||||
|
<Flex gap={3} alignItems="center" w="full">
|
||||||
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
|
<Input placeholder={t('modelManager.scanPlaceholder')} value={scanPath} onChange={handleSetScanPath} />
|
||||||
</Flex>
|
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
onClick={scanFolder}
|
onClick={scanFolder}
|
||||||
isLoading={isLoading}
|
isLoading={isLoading}
|
||||||
isDisabled={scanPath === undefined || scanPath.length === 0}
|
isDisabled={scanPath === undefined || scanPath.length === 0}
|
||||||
|
size="sm"
|
||||||
|
flexShrink={0}
|
||||||
>
|
>
|
||||||
{t('modelManager.scanFolder')}
|
{t('modelManager.scanFolder')}
|
||||||
</Button>
|
</Button>
|
||||||
</Flex>
|
</Flex>
|
||||||
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
{!!errorMessage.length && <FormErrorMessage>{errorMessage}</FormErrorMessage>}
|
||||||
</Flex>
|
|
||||||
</FormControl>
|
</FormControl>
|
||||||
{data && <ScanModelsResults results={data} />}
|
{data && <ScanModelsResults results={data} />}
|
||||||
</Flex>
|
</Flex>
|
||||||
|
@ -80,7 +80,7 @@ export const ScanModelsResults = ({ results }: ScanModelResultsProps) => {
|
|||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<Divider mt={4} />
|
<Divider mt={6} />
|
||||||
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
<Flex flexDir="column" gap={2} mt={4} height="100%">
|
||||||
<Flex justifyContent="space-between" alignItems="center">
|
<Flex justifyContent="space-between" alignItems="center">
|
||||||
<Heading fontSize="md" as="h4">
|
<Heading fontSize="md" as="h4">
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user