mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
added route for installing huggingface model from model marketplace
This commit is contained in:
parent
a3cb5da130
commit
aae318425d
@ -9,7 +9,7 @@ from copy import deepcopy
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from fastapi import Body, Path, Query, Response, UploadFile
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.responses import FileResponse, HTMLResponse
|
||||
from fastapi.routing import APIRouter
|
||||
from PIL import Image
|
||||
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
@ -501,6 +501,80 @@ async def install_model(
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
return result
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install/huggingface",
|
||||
operation_id="install_hugging_face_model",
|
||||
responses={
|
||||
201: {"description": "The model is being installed"},
|
||||
400: {"description": "Bad request"},
|
||||
409: {"description": "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_class=HTMLResponse
|
||||
)
|
||||
async def install_hugging_face_model(
|
||||
source: str = Query(description="Hugging Face repo_id to install"),
|
||||
) -> HTMLResponse:
|
||||
"""Install a Hugging Face model using a string identifier."""
|
||||
|
||||
def generate_html(message: str) -> str:
|
||||
return f"""
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {{
|
||||
text-align: center;
|
||||
margin-top: 50px;
|
||||
}}
|
||||
.message-box {{
|
||||
display: inline-block;
|
||||
padding: 20px;
|
||||
border: 1px solid #ccc;
|
||||
border-radius: 5px;
|
||||
background-color: #f9f9f9;
|
||||
box-shadow: 0px 0px 10px rgba(0, 0, 0, 0.1);
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="message-box">
|
||||
<p>{message}</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
try:
|
||||
metadata = HuggingFaceMetadataFetch().from_id(source)
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
message = "Your Hugging Face model is installing now. You can close this tab and check the Model Manager for installation progress."
|
||||
except UnknownMetadataException:
|
||||
message = "No HuggingFace repository found with that repo id."
|
||||
return HTMLResponse(content=generate_html(message), status_code=400)
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
installer = ApiDependencies.invoker.services.model_manager.install
|
||||
|
||||
if metadata.is_diffusers:
|
||||
installer.heuristic_import(
|
||||
source=source,
|
||||
inplace=False,
|
||||
)
|
||||
elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
|
||||
installer.heuristic_import(
|
||||
source=str(metadata.ckpt_urls[0]),
|
||||
inplace=False,
|
||||
)
|
||||
else:
|
||||
message = "This HuggingFace repo has multiple models. Please use the Model Manager to install this."
|
||||
except Exception as e:
|
||||
logger.error(str(e))
|
||||
message = "There was an error with installing this model. Please use the Model Manager to install this."
|
||||
|
||||
return HTMLResponse(content=generate_html(message), status_code=201)
|
||||
|
||||
|
||||
@model_manager_router.get(
|
||||
"/install",
|
||||
|
Loading…
Reference in New Issue
Block a user