2023-07-14 15:14:33 +00:00
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654), 2023 Kent Keirsey (https://github.com/hipsterusername), 2023 Lincoln D. Stein
2023-07-06 03:13:01 +00:00
2023-03-15 05:15:53 +00:00
2023-07-14 15:14:33 +00:00
import pathlib
2023-07-06 17:15:15 +00:00
from typing import Literal , List , Optional , Union
2023-03-15 05:15:53 +00:00
2023-07-05 10:08:47 +00:00
from fastapi import Body , Path , Query , Response
from fastapi . routing import APIRouter
2023-07-06 03:13:01 +00:00
from pydantic import BaseModel , parse_obj_as
2023-07-05 10:08:47 +00:00
from starlette . exceptions import HTTPException
2023-06-11 03:12:21 +00:00
from invokeai . backend import BaseModelType , ModelType
2023-07-05 19:13:21 +00:00
from invokeai . backend . model_management . models import (
OPENAPI_MODEL_CONFIGS ,
2023-07-06 17:15:15 +00:00
SchedulerPredictionType ,
2023-07-16 18:17:05 +00:00
ModelNotFoundException ,
2023-07-05 19:13:21 +00:00
)
2023-07-06 17:15:15 +00:00
from invokeai . backend . model_management import MergeInterpolationMethod
2023-07-16 18:17:05 +00:00
2023-07-05 10:08:47 +00:00
from . . dependencies import ApiDependencies
2023-03-15 05:15:53 +00:00
models_router = APIRouter(prefix="/v1/models", tags=["models"])

# Every request/response body in this router is "any one of the model config
# classes exposed to OpenAPI". Build that Union once and give it one alias per
# role so each route's annotation documents what it carries.
_AnyModelConfig = Union[tuple(OPENAPI_MODEL_CONFIGS)]

UpdateModelResponse = _AnyModelConfig
ImportModelResponse = _AnyModelConfig
ConvertModelResponse = _AnyModelConfig
MergeModelResponse = _AnyModelConfig
ImportModelAttributes = _AnyModelConfig
2023-06-23 20:35:39 +00:00
2023-03-15 05:15:53 +00:00
class ModelsList(BaseModel):
    # Response envelope for list_models: a list in which each entry is one of
    # the OpenAPI-exposed model config classes (OPENAPI_MODEL_CONFIGS).
    models: List[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
2023-03-15 05:15:53 +00:00
@models_router.get(
    "/",
    operation_id="list_models",
    responses={200: {"model": ModelsList}},
)
async def list_models(
    base_model: Optional[BaseModelType] = Query(default=None, description="Base model"),
    model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
) -> ModelsList:
    """Gets a list of models"""
    # Both filters default to None and are forwarded unchanged to the model manager.
    manager = ApiDependencies.invoker.services.model_manager
    return parse_obj_as(ModelsList, {"models": manager.list_models(base_model, model_type)})
2023-07-05 18:50:57 +00:00
@models_router.patch(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="update_model",
    responses={
        200: {"description": "The model was updated successfully"},
        400: {"description": "Bad request"},
        404: {"description": "The model could not be found"},
        409: {"description": "There is already a model corresponding to the new name"},
    },
    status_code=200,
    response_model=UpdateModelResponse,
)
async def update_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
    """Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
    # Error mapping (mirrors `responses` above):
    #   ModelNotFoundException -> 404
    #   ValueError             -> 409 (presumably a name collision on rename; confirm in model manager)
    #   any other Exception    -> 400
    # All failures are logged, and the HTTP error is chained to the original exception.
    logger = ApiDependencies.invoker.services.logger
    model_manager = ApiDependencies.invoker.services.model_manager

    try:
        # Snapshot the current record so we can detect whether a rename moved the model on disk.
        previous_info = model_manager.list_model(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        )

        # rename operation requested
        if info.model_name != model_name or info.base_model != base_model:
            model_manager.rename_model(
                base_model=base_model,
                model_type=model_type,
                model_name=model_name,
                new_name=info.model_name,
                new_base=info.base_model,
            )
            logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')

            # From here on the model is addressed by its new name/base.
            model_name = info.model_name
            base_model = info.base_model
            new_info = model_manager.list_model(
                model_name=model_name,
                base_model=base_model,
                model_type=model_type,
            )
            if new_info.get('path') != previous_info.get('path'):
                # The model manager moved the model path during the rename; keep the new
                # location instead of overwriting it with the path from the request body.
                info.path = new_info.get('path')

        # NOTE(review): info.model_type is not compared with the path's model_type;
        # a mismatched body is written under the path's type. Confirm this is intended.
        model_manager.update_model(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
            model_attributes=info.dict()
        )

        # Re-read the stored record so the response reflects what was persisted.
        model_raw = model_manager.list_model(
            model_name=model_name,
            base_model=base_model,
            model_type=model_type,
        )
        model_response = parse_obj_as(UpdateModelResponse, model_raw)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e)) from e
    except Exception as e:
        logger.error(str(e))
        raise HTTPException(status_code=400, detail=str(e)) from e

    return model_response
2023-06-23 20:35:39 +00:00
@models_router.post(
    "/import",
    operation_id="import_model",
    responses={
        201: {"description": "The model imported successfully"},
        404: {"description": "The model could not be found"},
        424: {"description": "The model appeared to import successfully, but could not be found in the model manager"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_model=ImportModelResponse
)
async def import_model(
    location: str = Body(description="A model path, repo_id or URL to import"),
    prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] = \
        Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
    """Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically"""
    # Error mapping: ModelNotFoundException -> 404, ValueError -> 409,
    # import reported success but the model is missing from the result -> 424.
    logger = ApiDependencies.invoker.services.logger
    model_manager = ApiDependencies.invoker.services.model_manager

    # Translate the request's string value into the SchedulerPredictionType enum once.
    # A None prediction_type maps to None.
    prediction_types = {x.value: x for x in SchedulerPredictionType}
    selected_prediction_type = prediction_types.get(prediction_type)

    try:
        installed_models = model_manager.heuristic_import(
            items_to_import={location},
            # Whatever heuristic_import passes to the helper is ignored: the
            # caller-chosen prediction type is used for the imported item.
            prediction_type_helper=lambda _: selected_prediction_type,
        )
        info = installed_models.get(location)

        if not info:
            logger.error(f"Import failed: {location}")
            raise HTTPException(
                status_code=424,
                detail=f"{location} appeared to import successfully, but could not be found in the model manager",
            )

        logger.info(f'Successfully imported {location}, got {info}')
        model_raw = model_manager.list_model(
            model_name=info.name,
            base_model=info.base_model,
            model_type=info.model_type
        )
        return parse_obj_as(ImportModelResponse, model_raw)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=str(e)) from e
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=409, detail=str(e)) from e
2023-07-15 03:03:18 +00:00
@models_router.post(
    "/add",
    operation_id="add_model",
    responses={
        201: {"description": "The model added successfully"},
        404: {"description": "The model could not be found"},
        424: {"description": "The model appeared to add successfully, but could not be found in the model manager"},
        409: {"description": "There is already a model corresponding to this path or repo_id"},
    },
    status_code=201,
    response_model=ImportModelResponse
)
async def add_model(
    info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> ImportModelResponse:
    """Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
    logger = ApiDependencies.invoker.services.logger
    try:
        manager = ApiDependencies.invoker.services.model_manager
        name, base, kind = info.model_name, info.base_model, info.model_type

        manager.add_model(name, base, kind, model_attributes=info.dict())
        logger.info(f'Successfully added {info.model_name}')

        # Read the record back so the response reflects what was stored.
        added = manager.list_model(model_name=name, base_model=base, model_type=kind)
        return parse_obj_as(ImportModelResponse, added)
    except (ModelNotFoundException, ValueError) as err:
        logger.error(str(err))
        # ModelNotFoundException is checked first, so it maps to 404 even if it
        # happens to subclass ValueError. Any other ValueError maps to 409.
        status = 404 if isinstance(err, ModelNotFoundException) else 409
        raise HTTPException(status_code=status, detail=str(err))
2023-04-06 19:17:48 +00:00
@models_router.delete(
    "/{base_model}/{model_type}/{model_name}",
    operation_id="del_model",
    responses={
        204: {"description": "Model deleted successfully"},
        404: {"description": "Model not found"}
    },
    status_code=204,
    response_model=None,
)
async def delete_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
) -> Response:
    """Delete Model"""
    logger = ApiDependencies.invoker.services.logger
    manager = ApiDependencies.invoker.services.model_manager

    try:
        manager.del_model(model_name, base_model=base_model, model_type=model_type)
    except ModelNotFoundException as err:
        logger.error(str(err))
        raise HTTPException(status_code=404, detail=str(err))

    # Success path: empty 204 response, as declared in the decorator.
    logger.info(f"Deleted model: {model_name}")
    return Response(status_code=204)
2023-04-06 20:23:09 +00:00
2023-07-06 03:13:01 +00:00
@models_router.put(
    "/convert/{base_model}/{model_type}/{model_name}",
    operation_id="convert_model",
    responses={
        200: {"description": "Model converted successfully"},
        400: {"description": "Bad request"},
        404: {"description": "Model not found"},
    },
    status_code=200,
    response_model=ConvertModelResponse,
)
async def convert_model(
    base_model: BaseModelType = Path(description="Base model"),
    model_type: ModelType = Path(description="The type of model"),
    model_name: str = Path(description="model name"),
    convert_dest_directory: Optional[str] = Query(default=None, description="Save the converted model to the designated directory"),
) -> ConvertModelResponse:
    """Convert a checkpoint model into a diffusers model, optionally saving to the indicated destination directory, or `models` if none."""
    # Error mapping: ModelNotFoundException -> 404, ValueError -> 400.
    # Both are logged and chained, matching the other endpoints in this router.
    logger = ApiDependencies.invoker.services.logger
    model_manager = ApiDependencies.invoker.services.model_manager
    try:
        logger.info(f"Converting model: {model_name}")
        # An absent or empty destination is passed as None, so the default location is used.
        dest = pathlib.Path(convert_dest_directory) if convert_dest_directory else None
        model_manager.convert_model(
            model_name,
            base_model=base_model,
            model_type=model_type,
            convert_dest_directory=dest,
        )
        # Re-read the record so the response reflects the post-conversion config.
        model_raw = model_manager.list_model(
            model_name,
            base_model=base_model,
            model_type=model_type,
        )
        response = parse_obj_as(ConvertModelResponse, model_raw)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {e}") from e
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=400, detail=str(e)) from e
    return response
2023-07-14 15:14:33 +00:00
@models_router.get(
    "/search",
    operation_id="search_for_models",
    responses={
        200: {"description": "Directory searched successfully"},
        404: {"description": "Invalid directory path"},
    },
    status_code=200,
    response_model=List[pathlib.Path]
)
async def search_for_models(
    search_path: pathlib.Path = Query(description="Directory path to search for models")
) -> List[pathlib.Path]:
    """Search a local directory for models and return the paths the model manager finds. Responds 404 if the path is not an existing directory."""
    if search_path.is_dir():
        return ApiDependencies.invoker.services.model_manager.search_for_models([search_path])
    raise HTTPException(status_code=404, detail=f"The search path '{search_path}' does not exist or is not directory")
2023-07-14 17:45:16 +00:00
@models_router.get(
    "/ckpt_confs",
    operation_id="list_ckpt_configs",
    responses={
        200: {"description": "paths retrieved successfully"},
    },
    status_code=200,
    response_model=List[pathlib.Path]
)
async def list_ckpt_configs(
) -> List[pathlib.Path]:
    """Return a list of the legacy checkpoint configuration files stored in `ROOT/configs/stable-diffusion`, relative to ROOT."""
    # Thin pass-through; the model manager does the directory listing.
    manager = ApiDependencies.invoker.services.model_manager
    return manager.list_checkpoint_configs()
@models_router.get(
    "/sync",
    operation_id="sync_to_config",
    responses={
        201: {"description": "synchronization successful"},
    },
    status_code=201,
    response_model=None
)
async def sync_to_config(
) -> None:
    """Call after making changes to models.yaml, autoimport directories or models directory to synchronize
    in-memory data structures with disk data structures."""
    # NOTE(review): this is a GET route that mutates server state. Kept as-is
    # because changing the method would break existing clients.
    manager = ApiDependencies.invoker.services.model_manager
    return manager.sync_to_config()
2023-07-05 19:13:21 +00:00
2023-07-06 17:15:15 +00:00
@models_router.put(
    "/merge/{base_model}",
    operation_id="merge_models",
    responses={
        200: {"description": "Model merged successfully"},
        400: {"description": "Incompatible models"},
        404: {"description": "One or more models not found"},
    },
    status_code=200,
    response_model=MergeModelResponse,
)
async def merge_models(
    base_model: BaseModelType = Path(description="Base model"),
    model_names: List[str] = Body(description="model name", min_items=2, max_items=3),
    merged_model_name: Optional[str] = Body(description="Name of destination model"),
    alpha: Optional[float] = Body(description="Alpha weighting strength to apply to 2d and 3d models", default=0.5),
    interp: Optional[MergeInterpolationMethod] = Body(description="Interpolation method"),
    force: Optional[bool] = Body(description="Force merging of models created with different versions of diffusers", default=False),
    merge_dest_directory: Optional[str] = Body(description="Save the merged model to the designated directory (with 'merged_model_name' appended)", default=None)
) -> MergeModelResponse:
    """Merge two or three models of the same base into a new main model. If no merged_model_name is given, the input names joined with '+' are used."""
    # Error mapping: ModelNotFoundException -> 404, ValueError -> 400 (e.g. incompatible models).
    logger = ApiDependencies.invoker.services.logger
    model_manager = ApiDependencies.invoker.services.model_manager
    try:
        # Resolve the effective name up front so the log line never shows "None".
        target_name = merged_model_name or "+".join(model_names)
        logger.info(f"Merging models: {model_names} into {merge_dest_directory or '<MODELS>'}/{target_name}")
        dest = pathlib.Path(merge_dest_directory) if merge_dest_directory else None
        result = model_manager.merge_models(
            model_names,
            base_model,
            merged_model_name=target_name,
            alpha=alpha,
            interp=interp,
            force=force,
            merge_dest_directory=dest,
        )
        # The merged model is registered as a main model under the requested base.
        model_raw = model_manager.list_model(
            result.name,
            base_model=base_model,
            model_type=ModelType.Main,
        )
        response = parse_obj_as(MergeModelResponse, model_raw)
    except ModelNotFoundException as e:
        logger.error(str(e))
        raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found: {e}") from e
    except ValueError as e:
        logger.error(str(e))
        raise HTTPException(status_code=400, detail=str(e)) from e
    return response
2023-07-16 18:17:05 +00:00
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))