2023-11-05 03:03:26 +00:00
# Copyright (c) 2023 Lincoln D. Stein
""" FastAPI route for model configuration records. """
2024-03-06 03:57:05 +00:00
import io
2024-02-02 17:18:47 +00:00
import pathlib
2024-03-06 03:57:05 +00:00
import traceback
2024-03-19 04:57:16 +00:00
from copy import deepcopy
2024-06-27 21:31:28 +00:00
from tempfile import TemporaryDirectory
2024-07-23 21:41:00 +00:00
from typing import List , Optional , Type
2023-11-05 03:03:26 +00:00
2024-03-06 03:57:05 +00:00
from fastapi import Body , Path , Query , Response , UploadFile
2024-06-14 21:08:39 +00:00
from fastapi . responses import FileResponse , HTMLResponse
2023-11-05 03:03:26 +00:00
from fastapi . routing import APIRouter
2024-03-06 03:57:05 +00:00
from PIL import Image
2024-03-07 21:57:28 +00:00
from pydantic import AnyHttpUrl , BaseModel , ConfigDict , Field
2023-11-05 03:03:26 +00:00
from starlette . exceptions import HTTPException
2023-11-11 02:32:44 +00:00
from typing_extensions import Annotated
2023-11-05 03:03:26 +00:00
2024-07-03 16:04:22 +00:00
from invokeai . app . api . dependencies import ApiDependencies
2024-05-10 01:17:35 +00:00
from invokeai . app . services . model_images . model_images_common import ModelImageFileNotFoundException
2024-03-14 08:04:19 +00:00
from invokeai . app . services . model_install . model_install_common import ModelInstallJob
2023-11-13 23:12:45 +00:00
from invokeai . app . services . model_records import (
InvalidModelException ,
2024-03-22 01:14:45 +00:00
ModelRecordChanges ,
2023-11-13 23:12:45 +00:00
UnknownModelException ,
)
from invokeai . backend . model_manager . config import (
AnyModelConfig ,
BaseModelType ,
2024-03-05 07:40:17 +00:00
MainCheckpointConfig ,
2024-01-14 19:54:53 +00:00
ModelFormat ,
2023-11-13 23:12:45 +00:00
ModelType ,
)
2024-03-13 01:00:14 +00:00
from invokeai . backend . model_manager . metadata . fetch . huggingface import HuggingFaceMetadataFetch
from invokeai . backend . model_manager . metadata . metadata_base import ModelMetadataWithFiles , UnknownMetadataException
2024-02-21 16:54:02 +00:00
from invokeai . backend . model_manager . search import ModelSearch
2024-03-22 03:11:25 +00:00
from invokeai . backend . model_manager . starter_models import STARTER_MODELS , StarterModel , StarterModelWithoutDependencies
2023-11-05 03:03:26 +00:00
2024-02-18 06:27:42 +00:00
# Router that every endpoint in this module attaches to; its routes are served under /v2/models.
model_manager_router = APIRouter(prefix="/v2/models", tags=["model_manager"])

# images are immutable; set a high max-age
# (31536000 seconds = 365 days; sent as the Cache-Control max-age when serving model images)
IMAGE_MAX_AGE = 31536000
2024-02-14 16:10:50 +00:00
class ModelsList(BaseModel):
    """Return list of configs."""

    # The model configuration records included in the response.
    models: List[AnyModelConfig]

    # Pydantic setting: populate enum-typed fields with the enum's value rather than the enum member.
    model_config = ConfigDict(use_enum_values=True)
2024-05-10 00:42:34 +00:00
def add_cover_image_to_model_config(config: AnyModelConfig, dependencies: Type[ApiDependencies]) -> AnyModelConfig:
    """Attach the cover image URL for ``config.key`` to ``config`` and return it.

    The record is modified in place (its ``cover_image`` attribute is set); the
    same object that was passed in is handed back.
    """
    image_service = dependencies.invoker.services.model_images
    config.cover_image = image_service.get_url(config.key)
    return config
2024-02-14 16:10:50 +00:00
##############################################################################
# These are example inputs and outputs that are used in places where Swagger
# is unable to generate a correct example.
##############################################################################
# Example response body: used as the "application/json" example in the
# `responses` metadata of the routes that return a model configuration.
example_model_config = {
    "path": "string",
    "name": "string",
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
    "config_path": "string",
    "key": "string",
    "hash": "string",
    "description": "string",
    "source": "string",
    "converted_at": 0,
    "variant": "normal",
    "prediction_type": "epsilon",
    "repo_variant": "fp16",
    "upcast_attention": False,
}
# Example request body: passed as the `example` of the `changes` body parameter
# of the update-model route.
example_model_input = {
    "path": "/path/to/model",
    "name": "model_name",
    "base": "sd-1",
    "type": "main",
    "format": "checkpoint",
    "config_path": "configs/stable-diffusion/v1-inference.yaml",
    "description": "Model description",
    "vae": None,
    "variant": "normal",
}
2024-02-14 16:10:50 +00:00
##############################################################################
# ROUTES
##############################################################################
2024-01-14 19:54:53 +00:00
2024-02-18 06:27:42 +00:00
@model_manager_router.get(
    "/",
    operation_id="list_model_records",
)
async def list_model_records(
    base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
    model_name: Optional[str] = Query(default=None, description="Exact match on the name of the model"),
    model_format: Optional[ModelFormat] = Query(
        default=None, description="Exact match on the format of the model (e.g. 'diffusers')"
    ),
) -> ModelsList:
    """Get a list of models."""
    store = ApiDependencies.invoker.services.model_manager.store

    # Filters applied to every store query, with or without a base-model restriction.
    shared_filters = {"model_type": model_type, "model_name": model_name, "model_format": model_format}

    matches: list[AnyModelConfig] = []
    if not base_models:
        matches.extend(store.search_by_attr(**shared_filters))
    else:
        # One store query per requested base model; results are concatenated in request order.
        for base in base_models:
            matches.extend(store.search_by_attr(base_model=base, **shared_filters))

    # The helper sets the cover image on each record in place.
    for record in matches:
        add_cover_image_to_model_config(record, ApiDependencies)

    return ModelsList(models=matches)
2023-11-05 03:03:26 +00:00
2024-02-22 22:53:23 +00:00
@model_manager_router.get(
    "/get_by_attrs",
    operation_id="get_model_records_by_attrs",
    response_model=AnyModelConfig,
)
async def get_model_records_by_attrs(
    name: str = Query(description="The name of the model"),
    type: ModelType = Query(description="The type of the model"),
    base: BaseModelType = Query(description="The base model of the model"),
) -> AnyModelConfig:
    """Gets a model by its attributes. The main use of this route is to provide backwards compatibility with the old
    model manager, which identified models by a combination of name, base and type."""
    store = ApiDependencies.invoker.services.model_manager.store
    matches = store.search_by_attr(base_model=base, model_type=type, model_name=name)

    # Only the first match is returned when several records share these attributes.
    if matches:
        return matches[0]
    raise HTTPException(status_code=404, detail="No model found with these attributes")
2024-02-18 06:27:42 +00:00
@model_manager_router.get(
    "/i/{key}",
    operation_id="get_model_record",
    responses={
        200: {
            "description": "The model configuration was retrieved successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
    },
)
async def get_model_record(
    key: str = Path(description="Key of the model record to fetch."),
) -> AnyModelConfig:
    """Get a model record"""
    store = ApiDependencies.invoker.services.model_manager.store
    try:
        record = store.get_model(key)
        return add_cover_image_to_model_config(record, ApiDependencies)
    except UnknownModelException as not_found:
        # An unknown key is reported to the client as 404 with the store's message.
        raise HTTPException(status_code=404, detail=str(not_found))
2024-02-24 07:46:54 +00:00
class FoundModel(BaseModel):
    """One entry of a folder scan result: a model path and whether it is already installed."""

    path: str = Field(description="Path to the model")
    is_installed: bool = Field(description="Whether or not the model is already installed")
2024-02-21 16:54:02 +00:00
@model_manager_router.get(
    "/scan_folder",
    operation_id="scan_for_models",
    responses={
        200: {"description": "Directory scanned successfully"},
        400: {"description": "Invalid directory path"},
    },
    status_code=200,
    response_model=List[FoundModel],
)
async def scan_for_models(
    scan_path: str = Query(description="Directory path to search for models", default=None),
) -> List[FoundModel]:
    """Scan a directory for models and report which of them are already installed.

    Args:
        scan_path: Directory to search. Omitting it, passing an empty string, or
            passing a path that is not a directory is rejected.

    Returns:
        One `FoundModel` per model path found outside the "core" models directory.

    Raises:
        HTTPException: 400 when `scan_path` is missing or not a directory;
            500 when the search itself fails.
    """
    # Validate before building a Path: the parameter defaults to None, and
    # pathlib.Path(None) raises TypeError, which would surface as a 500
    # instead of the documented 400.
    search_root = pathlib.Path(scan_path) if scan_path else None
    if search_root is None or not search_root.is_dir():
        raise HTTPException(
            status_code=400,
            detail=f"The search path '{scan_path}' does not exist or is not directory",
        )

    search = ModelSearch()
    try:
        found_model_paths = search.search(search_root)
        models_path = ApiDependencies.invoker.services.configuration.models_path

        # If the search path includes the main models directory, we need to exclude core models from the list.
        # TODO(MM2): Core models should be handled by the model manager so we can determine if they are installed
        # without needing to crawl the filesystem.
        core_models_path = pathlib.Path(models_path, "core").resolve()
        non_core_model_paths = [p for p in found_model_paths if not p.is_relative_to(core_models_path)]

        installed_models = ApiDependencies.invoker.services.model_manager.store.search_by_attr()

        # Build the set of installed paths once so each found path is a single lookup
        # rather than a scan over every installed model.
        installed_paths = {str(models_path / m.path) for m in installed_models}

        # A model counts as installed when its path string matches an installed record's path.
        scan_results: list[FoundModel] = [
            FoundModel(path=str(p), is_installed=str(p) in installed_paths) for p in non_core_model_paths
        ]
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred while searching the directory: {e}",
        ) from e
    return scan_results
2024-02-21 16:54:02 +00:00
2024-01-14 19:54:53 +00:00
2024-03-13 01:00:14 +00:00
class HuggingFaceModels(BaseModel):
    """Summary of a Hugging Face repo: its checkpoint URLs and whether it is a Diffusers model."""

    urls: List[AnyHttpUrl] | None = Field(description="URLs for all checkpoint format models in the metadata")
    is_diffusers: bool = Field(description="Whether the metadata is for a Diffusers format model")
2024-03-07 21:57:28 +00:00
@model_manager_router.get(
    "/hugging_face",
    operation_id="get_hugging_face_models",
    responses={
        200: {"description": "Hugging Face repo scanned successfully"},
        400: {"description": "Invalid hugging face repo"},
    },
    status_code=200,
    response_model=HuggingFaceModels,
)
async def get_hugging_face_models(
    hugging_face_repo: str = Query(description="Hugging face repo to search for models", default=None),
) -> HuggingFaceModels:
    """Fetch metadata for a Hugging Face repo and report its checkpoint URLs and Diffusers flag.

    Responds with 400 when the metadata fetcher raises `UnknownMetadataException`.
    """
    try:
        repo_metadata = HuggingFaceMetadataFetch().from_id(hugging_face_repo)
    except UnknownMetadataException:
        raise HTTPException(status_code=400, detail="No HuggingFace repository found")

    # The fields read below are only accessed once the metadata is known to be file-bearing.
    assert isinstance(repo_metadata, ModelMetadataWithFiles)
    checkpoint_urls = repo_metadata.ckpt_urls
    diffusers_flag = repo_metadata.is_diffusers
    return HuggingFaceModels(urls=checkpoint_urls, is_diffusers=diffusers_flag)
2024-03-07 21:57:28 +00:00
2024-02-18 06:27:42 +00:00
@model_manager_router.patch(
    "/i/{key}",
    operation_id="update_model_record",
    responses={
        200: {
            "description": "The model was updated successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
        409: {"description": "There is already a model corresponding to the new name"},
    },
    status_code=200,
)
async def update_model_record(
    key: Annotated[str, Path(description="Unique key of model")],
    changes: Annotated[ModelRecordChanges, Body(description="Model config", example=example_model_input)],
) -> AnyModelConfig:
    """Update a model's config."""
    services = ApiDependencies.invoker.services
    logger = services.logger
    store = services.model_manager.store
    installer = services.model_manager.install

    try:
        # Persist the changes first, then let the installer sync the model path for this key.
        store.update_model(key, changes=changes)
        updated = add_cover_image_to_model_config(installer.sync_model_path(key), ApiDependencies)
        logger.info(f"Updated model: {key}")
        return updated
    except UnknownModelException as e:
        raise HTTPException(status_code=404, detail=str(e))
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
2023-11-05 03:42:44 +00:00
2024-03-06 03:57:05 +00:00
@model_manager_router.get(
    "/i/{key}/image",
    operation_id="get_model_image",
    responses={
        200: {
            "description": "The model image was fetched successfully",
        },
        400: {"description": "Bad request"},
        404: {"description": "The model image could not be found"},
    },
    status_code=200,
)
async def get_model_image(
    key: str = Path(description="The name of model image file to get"),
) -> FileResponse:
    """Gets an image file that previews the model"""
    image_service = ApiDependencies.invoker.services.model_images
    try:
        image_path = image_service.get_path(key)
        image_response = FileResponse(
            image_path,
            media_type="image/png",
            filename=f"{key}.png",
            content_disposition_type="inline",
        )
        # Let clients cache the image for IMAGE_MAX_AGE seconds.
        image_response.headers["Cache-Control"] = f"max-age={IMAGE_MAX_AGE}"
        return image_response
    except Exception:
        # Any failure while locating or wrapping the file is reported as "not found".
        raise HTTPException(status_code=404)
2024-03-06 19:18:21 +00:00
2024-03-06 03:57:05 +00:00
@model_manager_router.patch(
    "/i/{key}/image",
    operation_id="update_model_image",
    responses={
        200: {
            "description": "The model image was updated successfully",
        },
        400: {"description": "Bad request"},
    },
    status_code=200,
)
async def update_model_image(
    key: Annotated[str, Path(description="Unique key of model")],
    image: UploadFile,
) -> None:
    """Replace the image stored for a model with the uploaded file.

    Responds with 415 when the upload is not declared as an image or cannot be
    opened as one, and with 409 when saving raises `ValueError`.
    """
    declared_type = image.content_type
    if not (declared_type and declared_type.startswith("image")):
        raise HTTPException(status_code=415, detail="Not an image")

    raw_bytes = await image.read()

    try:
        pil_image = Image.open(io.BytesIO(raw_bytes))
    except Exception:
        # Log the full traceback; the client only gets a generic message.
        ApiDependencies.invoker.services.logger.error(traceback.format_exc())
        raise HTTPException(status_code=415, detail="Failed to read image")

    services = ApiDependencies.invoker.services
    logger = services.logger
    image_service = services.model_images
    try:
        image_service.save(pil_image, key)
        logger.info(f"Updated image for model: {key}")
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
2024-02-18 06:27:42 +00:00
@model_manager_router.delete(
    "/i/{key}",
    operation_id="delete_model",
    responses={
        204: {"description": "Model deleted successfully"},
        404: {"description": "Model not found"},
    },
    status_code=204,
)
async def delete_model(
    key: str = Path(description="Unique key of model to remove from model registry."),
) -> Response:
    """
    Delete model record from database.

    The configuration record will be removed. The corresponding weights files will be
    deleted as well if they reside within the InvokeAI "models" directory.
    """
    services = ApiDependencies.invoker.services
    logger = services.logger
    try:
        services.model_manager.install.delete(key)
    except UnknownModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e))
    else:
        logger.info(f"Deleted model: {key}")
        return Response(status_code=204)
2024-03-06 19:18:21 +00:00
2024-03-06 14:58:39 +00:00
@model_manager_router.delete(
    "/i/{key}/image",
    operation_id="delete_model_image",
    responses={
        204: {"description": "Model image deleted successfully"},
        404: {"description": "Model image not found"},
    },
    status_code=204,
)
async def delete_model_image(
    key: str = Path(description="Unique key of model image to remove from model_images directory."),
) -> None:
    """Delete the image stored for the model with the given key.

    Responds with 404 when the image service raises `UnknownModelException`.
    """
    services = ApiDependencies.invoker.services
    logger = services.logger
    try:
        services.model_images.delete(key)
    # NOTE(review): the image service may signal a missing image with its own
    # exception type rather than UnknownModelException - confirm which one
    # `model_images.delete` raises, otherwise a missing image surfaces as a 500.
    except UnknownModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e))
    else:
        logger.info(f"Deleted model image: {key}")
2023-11-05 03:42:44 +00:00
2024-02-18 06:27:42 +00:00
@model_manager_router.post(
    "/install",
    operation_id="install_model",
    responses={
        201: {"description": "The model imported successfully"},
        415: {"description": "Unrecognized file/folder format"},
        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
)
async def install_model(
    source: str = Query(description="Model source to install, can be a local path, repo_id, or remote URL"),
    inplace: Optional[bool] = Query(description="Whether or not to install a local model in place", default=False),
    access_token: Optional[str] = Query(description="access token for the remote resource", default=None),
    config: ModelRecordChanges = Body(
        description="Object containing fields that override auto-probed values in the model config record, such as name, description and prediction_type",
        example={"name": "string", "description": "string"},
    ),
) -> ModelInstallJob:
    """Install a model using a string identifier.

    `source` can be any of the following.

    1. A path on the local filesystem ('C:\\users\\fred\\model.safetensors')
    2. A Url pointing to a single downloadable model file
    3. A HuggingFace repo_id with any of the following formats:
       - model/name
       - model/name:fp16:vae
       - model/name::vae          -- use default precision
       - model/name:fp16:path/to/model.safetensors
       - model/name::path/to/model.safetensors

    `config` is a ModelRecordChanges object. Fields in this object will override
    the ones that are probed automatically. Pass an empty object to accept
    all the defaults.

    `access_token` is an optional access token for use with Urls that require
    authentication.

    Models will be downloaded, probed, configured and installed in a
    series of background threads. The return object has `status` attribute
    that can be used to monitor progress.

    See the documentation for `import_model_record` for more information on
    interpreting the job information returned by this route.
    """
    services = ApiDependencies.invoker.services
    logger = services.logger
    try:
        install_service = services.model_manager.install
        # `inplace` is Optional in the signature; coerce it so the installer always gets a bool.
        job: ModelInstallJob = install_service.heuristic_import(
            source=source,
            config=config,
            access_token=access_token,
            inplace=bool(inplace),
        )
        logger.info(f"Started installation of {source}")
        return job
    except UnknownModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=424, detail=str(e))
    except InvalidModelException as e:
        logger.error(str(e))
        raise HTTPException(status_code=415)
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e))
2024-06-14 21:15:55 +00:00
2024-06-14 21:08:39 +00:00
@model_manager_router.get(
    "/install/huggingface",
    operation_id="install_hugging_face_model",
    responses={
        201: {"description": "The model is being installed"},
        400: {"description": "Bad request"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_class=HTMLResponse,
)
async def install_hugging_face_model(
    source: str = Query(description="HuggingFace repo_id to install"),
) -> HTMLResponse:
    """Install a Hugging Face model using a string identifier.

    The outcome is reported as a small HTML page:

    * 201 - an install was started (a Diffusers repo, or a repo with exactly one checkpoint URL)
    * 200 - the repo holds several checkpoint models, so nothing was started
    * 400 - no metadata could be found for the repo ID
    * 500 - any other failure while starting the install
    """
    # Imported locally so this block does not depend on a module-level import.
    from html import escape

    def generate_html(title: str, heading: str, repo_id: str, is_error: bool, message: str | None = "") -> str:
        """Render the status page.

        `title`, `heading` and `message` are server-side constants and are inserted
        as-is (`message` may deliberately contain markup such as a link). `repo_id`
        comes straight from the request's query string, so it is HTML-escaped to
        keep a crafted value from injecting markup into the page.
        """
        if message:
            message = f"<p>{message}</p>"
        title_class = "error" if is_error else "success"
        safe_repo_id = escape(repo_id)
        return f"""
            <html>

            <head>
                <title>{title}</title>
                <style>
                    body {{
                        text-align: center;
                        background-color: hsl(220 12% 10% / 1);
                        font-family: Helvetica, sans-serif;
                        color: hsl(220 12% 86% / 1);
                    }}

                    .repo-id {{
                        color: hsl(220 12% 68% / 1);
                    }}

                    .error {{
                        color: hsl(0 42% 68% / 1)
                    }}

                    .message-box {{
                        display: inline-block;
                        border-radius: 5px;
                        background-color: hsl(220 12% 20% / 1);
                        padding-inline-end: 30px;
                        padding: 20px;
                        padding-inline-start: 30px;
                        padding-inline-end: 30px;
                    }}

                    .container {{
                        display: flex;
                        width: 100%;
                        height: 100%;
                        align-items: center;
                        justify-content: center;
                    }}

                    a {{
                        color: inherit
                    }}

                    a:visited {{
                        color: inherit
                    }}

                    a:active {{
                        color: inherit
                    }}
                </style>
            </head>

            <body style="background-color: hsl(220 12% 10% / 1);">
                <div class="container">
                    <div class="message-box">
                        <h2 class="{title_class}">{heading}</h2>
                        {message}
                        <p class="repo-id">Repo ID: {safe_repo_id}</p>
                    </div>
                </div>
            </body>

            </html>
        """

    try:
        metadata = HuggingFaceMetadataFetch().from_id(source)
        assert isinstance(metadata, ModelMetadataWithFiles)
    except UnknownMetadataException:
        title = "Unable to Install Model"
        heading = "No HuggingFace repository found with that repo ID."
        message = "Ensure the repo ID is correct and try again."
        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=400)

    logger = ApiDependencies.invoker.services.logger

    try:
        installer = ApiDependencies.invoker.services.model_manager.install
        if metadata.is_diffusers:
            # Diffusers repos are installed by repo ID.
            installer.heuristic_import(
                source=source,
                inplace=False,
            )
        elif metadata.ckpt_urls is not None and len(metadata.ckpt_urls) == 1:
            # Exactly one checkpoint file: install it from its URL.
            installer.heuristic_import(
                source=str(metadata.ckpt_urls[0]),
                inplace=False,
            )
        else:
            # Zero or several checkpoint files: nothing is installed from this route.
            title = "Unable to Install Model"
            heading = "This HuggingFace repo has multiple models."
            message = "Please use the Model Manager to install this model."
            return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=200)

        title = "Model Install Started"
        heading = "Your HuggingFace model is installing now."
        message = "You can close this tab and check the Model Manager for installation progress."
        return HTMLResponse(content=generate_html(title, heading, source, False, message), status_code=201)
    except Exception as e:
        logger.error(str(e))
        title = "Unable to Install Model"
        heading = "There was a problem installing this model."
        message = 'Please use the Model Manager directly to install this model. If the issue persists, ask for help on <a href="https://discord.gg/ZmtBAhwWhy">discord</a>.'
        return HTMLResponse(content=generate_html(title, heading, source, True, message), status_code=500)
2024-06-14 21:08:39 +00:00
2023-11-26 22:13:31 +00:00
2024-02-18 06:27:42 +00:00
@model_manager_router.get(
    "/install",
    operation_id="list_model_installs",
)
async def list_model_installs() -> List[ModelInstallJob]:
    """Return the list of model install jobs.

    Install jobs have a numeric `id`, a `status`, and other fields that provide information on
    the nature of the job and its progress. The `status` is one of:

    * "waiting" -- Job is waiting in the queue to run
    * "downloading" -- Model file(s) are downloading
    * "running" -- Model has downloaded and the model probing and registration process is running
    * "completed" -- Installation completed successfully
    * "error" -- An error occurred. Details will be in the "error_type" and "error" fields.
    * "cancelled" -- Job was cancelled before completion.

    Once completed, information about the model such as its size, base
    model and type can be retrieved from the `config_out` field. For multi-file models such as diffusers,
    information on individual files can be retrieved from `download_parts`.

    See the example and schema below for more information.
    """
    install_service = ApiDependencies.invoker.services.model_manager.install
    return install_service.list_jobs()
2023-11-26 18:18:21 +00:00
2023-11-26 22:13:31 +00:00
2024-02-18 06:27:42 +00:00
@model_manager_router.get(
    "/install/{id}",
    operation_id="get_model_install_job",
    responses={
        200: {"description": "Success"},
        404: {"description": "No such job"},
    },
)
async def get_model_install_job(id: int = Path(description="Model install id")) -> ModelInstallJob:
    """
    Return model install job corresponding to the given job id. See the documentation for 'List Model Install Jobs'
    for information on the format of the return value.
    """
    install_service = ApiDependencies.invoker.services.model_manager.install
    try:
        return install_service.get_job_by_id(id)
    except ValueError as e:
        # The lookup's ValueError is reported to the client as 404.
        raise HTTPException(status_code=404, detail=str(e))
2024-02-18 06:27:42 +00:00
@model_manager_router.delete(
    "/install/{id}",
    operation_id="cancel_model_install_job",
    responses={
        201: {"description": "The job was cancelled successfully"},
        415: {"description": "No such job"},
    },
    status_code=201,
)
async def cancel_model_install_job(id: int = Path(description="Model install job ID")) -> None:
    """Cancel the model install job(s) corresponding to the given job ID."""
    install_service = ApiDependencies.invoker.services.model_manager.install
    try:
        target_job = install_service.get_job_by_id(id)
    except ValueError as e:
        # This route reports an unknown job id as 415, as declared in its `responses` above.
        raise HTTPException(status_code=415, detail=str(e))
    else:
        install_service.cancel_job(target_job)
2024-03-05 07:40:17 +00:00
@model_manager_router.delete(
    "/install",
    operation_id="prune_model_install_jobs",
    responses={
        204: {"description": "All completed and errored jobs have been pruned"},
        400: {"description": "Bad request"},
    },
)
async def prune_model_install_jobs() -> Response:
    """Prune all completed and errored jobs from the install job list."""
    install_service = ApiDependencies.invoker.services.model_manager.install
    install_service.prune_jobs()
    # Empty body; the status code alone signals success.
    return Response(status_code=204)
2023-11-26 22:13:31 +00:00
2024-03-05 07:40:17 +00:00
@model_manager_router.put(
    "/convert/{key}",
    operation_id="convert_model",
    responses={
        200: {
            "description": "Model converted successfully",
            "content": {"application/json": {"example": example_model_config}},
        },
        400: {"description": "Bad request"},
        404: {"description": "Model not found"},
        409: {"description": "There is already a model registered at this location"},
    },
)
async def convert_model(
    key: str = Path(description="Unique key of the safetensors main model to convert to diffusers format."),
) -> AnyModelConfig:
    """
    Permanently convert a model into diffusers format, replacing the safetensors version.
    Note that during the conversion process the key and model hash will change.
    The return value is the model configuration for the converted model.
    """
    model_manager = ApiDependencies.invoker.services.model_manager
    loader = model_manager.load
    logger = ApiDependencies.invoker.services.logger
    store = ApiDependencies.invoker.services.model_manager.store
    installer = ApiDependencies.invoker.services.model_manager.install

    try:
        model_config = store.get_model(key)
    except UnknownModelException as e:
        logger.error(str(e))
        # NOTE(review): an unknown key is reported as 424 here, while the route's
        # `responses` metadata documents 404 for "Model not found" - confirm which is intended.
        raise HTTPException(status_code=424, detail=str(e))

    # Only main checkpoint models are accepted for conversion.
    if not isinstance(model_config, MainCheckpointConfig):
        logger.error(f"The model with key {key} is not a main checkpoint model.")
        raise HTTPException(400, f"The model with key {key} is not a main checkpoint model.")

    # The converted model is written to a temporary directory created inside the models
    # directory; the context manager removes that directory when the block exits.
    with TemporaryDirectory(dir=ApiDependencies.invoker.services.configuration.models_path) as tmpdir:
        convert_path = pathlib.Path(tmpdir) / pathlib.Path(model_config.path).stem
        converted_model = loader.load_model(model_config)
        # write the converted file to the convert path
        raw_model = converted_model.model
        # NOTE(review): these asserts are skipped when Python runs with -O.
        assert hasattr(raw_model, "save_pretrained")
        raw_model.save_pretrained(convert_path)  # type: ignore
        assert convert_path.exists()

        # temporarily rename the original safetensors file so that there is no naming conflict
        original_name = model_config.name
        model_config.name = f"{original_name}.DELETE"
        changes = ModelRecordChanges(name=model_config.name)
        store.update_model(key, changes=changes)

        # install the diffusers
        try:
            new_key = installer.install_path(
                convert_path,
                # Carry the original record's identifying fields over to the new record.
                config=ModelRecordChanges(
                    name=original_name,
                    description=model_config.description,
                    hash=model_config.hash,
                    source=model_config.source,
                ),
            )
        except Exception as e:
            logger.error(str(e))
            # Roll back the temporary rename so the original record keeps its name.
            store.update_model(key, changes=ModelRecordChanges(name=original_name))
            raise HTTPException(status_code=409, detail=str(e))

    # Update the model image if the model had one
    try:
        model_image = ApiDependencies.invoker.services.model_images.get(key)
        ApiDependencies.invoker.services.model_images.save(model_image, new_key)
        ApiDependencies.invoker.services.model_images.delete(key)
    except ModelImageFileNotFoundException:
        # The original model had no image; there is nothing to move.
        pass

    # delete the original safetensors file
    installer.delete(key)

    # No explicit cleanup of the temporary directory is needed here: the
    # TemporaryDirectory context manager above has already removed it.

    # return the config record for the new diffusers directory
    new_config = store.get_model(new_key)
    new_config = add_cover_image_to_model_config(new_config, ApiDependencies)
    return new_config
2024-03-05 05:32:16 +00:00
2024-03-19 04:57:16 +00:00
@model_manager_router.get("/starter_models", operation_id="get_starter_models", response_model=list[StarterModel])
async def get_starter_models() -> list[StarterModel]:
    """Return the starter models, flagged with whether each is installed.

    A starter model counts as installed when its `source` matches the source of
    an installed model. Dependencies whose source is already installed are
    dropped from each entry's `dependencies`.
    """
    store = ApiDependencies.invoker.services.model_manager.store
    installed_sources = {record.source for record in store.search_by_attr()}

    # Work on a deep copy so the module-level STARTER_MODELS list is never modified.
    catalog = deepcopy(STARTER_MODELS)
    for starter in catalog:
        if starter.source in installed_sources:
            starter.is_installed = True
        # Remove already-installed dependencies
        starter.dependencies = [
            dep for dep in (starter.dependencies or []) if dep.source not in installed_sources
        ]
    return catalog