2023-04-07 02:25:18 +00:00
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
2023-03-15 05:15:53 +00:00
2023-06-23 20:35:39 +00:00
from typing import Literal , Optional , Union
2023-03-15 05:15:53 +00:00
2023-07-03 23:32:54 +00:00
from fastapi import Body, Query, Response
2023-04-06 19:17:48 +00:00
from fastapi . routing import APIRouter , HTTPException
2023-03-15 05:15:53 +00:00
from pydantic import BaseModel , Field , parse_obj_as
from . . dependencies import ApiDependencies
2023-06-11 03:12:21 +00:00
from invokeai . backend import BaseModelType , ModelType
2023-07-03 23:32:54 +00:00
from invokeai . backend . model_management import AddModelResult
2023-06-23 20:35:39 +00:00
from invokeai . backend . model_management . models import OPENAPI_MODEL_CONFIGS , SchedulerPredictionType
2023-06-17 14:15:36 +00:00
# Union of every model configuration class exported by the backend for the
# OpenAPI schema; it is the element type of ModelsList.models.
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]

# Router for all model-management endpoints: mounted under /v1/models and
# grouped under the "models" tag in the generated OpenAPI document.
models_router = APIRouter(tags=["models"], prefix="/v1/models")
class VaeRepo(BaseModel):
    # Where to find a VAE; used as the `vae` field of DiffusersModelInfo.
    # (No class docstring on purpose: pydantic would publish it as the
    # schema description.)
    repo_id: str = Field(description="The repo ID to use for this VAE")
    # Optional fields: pydantic treats Union[X, None] exactly like Optional[X].
    path: Union[str, None] = Field(description="The path to the VAE")
    subfolder: Union[str, None] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
    # Fields common to every model-info payload. Format-specific subclasses
    # add a `format` literal that the request/response unions discriminate on.
    description: Union[str, None] = Field(description="A description of the model")
    model_name: str = Field(description="The name of the model")
    model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
    # Diffusers-style ("folder") model. It may carry a repo ID and/or a local
    # path; neither is required by the schema.
    format: Literal["folder"] = "folder"
    vae: Union[VaeRepo, None] = Field(description="The VAE repo to use for this model")
    repo_id: Union[str, None] = Field(description="The repo ID to use for this model")
    path: Union[str, None] = Field(description="The path to the model")
2023-03-15 05:15:53 +00:00
class CkptModelInfo(ModelInfo):
    # Checkpoint model: config, weights and VAE paths are all required;
    # width/height are optional.
    format: Literal["ckpt"] = "ckpt"
    config: str = Field(description="The path to the model config")
    weights: str = Field(description="The path to the model weights")
    vae: str = Field(description="The path to the model VAE")
    width: Union[int, None] = Field(description="The width of the model")
    height: Union[int, None] = Field(description="The height of the model")
2023-05-13 18:44:44 +00:00
class SafetensorsModelInfo(CkptModelInfo):
    # Identical fields to CkptModelInfo; only the `format` tag differs.
    format: Literal["safetensors"] = "safetensors"
2023-03-15 05:15:53 +00:00
2023-04-07 02:25:18 +00:00
class CreateModelRequest(BaseModel):
    # Request body accepted by the update_model endpoint.
    name: str = Field(description="The name of the model")
    # The `format` field picks which info class pydantic validates against.
    # NOTE(review): SafetensorsModelInfo is not a member of this union, so a
    # payload with format "safetensors" is rejected — confirm that is intended.
    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(
        discriminator="format",
        description="The model info",
    )
2023-04-06 19:17:48 +00:00
2023-04-07 02:25:18 +00:00
class CreateModelResponse(BaseModel):
    # Reply from update_model: echoes the submitted name/info plus a status.
    name: str = Field(description="The name of the new model")
    info: Union[CkptModelInfo, DiffusersModelInfo] = Field(
        discriminator="format",
        description="The model info",
    )
    status: str = Field(description="The status of the API response")
2023-07-03 23:32:54 +00:00
class ImportModelResponse(BaseModel):
    # Reply from import_model; `info` is the backend's AddModelResult as-is.
    name: str = Field(description="The name of the imported model")
    # Disabled fields, kept for reference:
    # base_model: str = Field(description="The base model")
    # model_type: str = Field(description="The model type")
    info: AddModelResult = Field(description="The model info")
    status: str = Field(description="The status of the API response")
2023-06-23 20:35:39 +00:00
2023-04-07 02:25:18 +00:00
class ConversionRequest(BaseModel):
    # Payload describing a checkpoint model to convert and where to save the
    # result. No route in this section consumes it yet.
    name: str = Field(description="The name of the new model")
    info: CkptModelInfo = Field(description="The converted model info")
    save_location: str = Field(
        description="The path to save the converted model weights"
    )
2023-04-06 19:17:48 +00:00
2023-04-07 02:25:18 +00:00
class ConvertedModelResponse(BaseModel):
    # Result of a conversion: the new model's name and its diffusers-format info.
    # No route in this section returns it yet.
    name: str = Field(description="The name of the new model")
    info: DiffusersModelInfo = Field(description="The converted model info")
2023-03-15 05:15:53 +00:00
class ModelsList(BaseModel):
    # Response model of list_models; each entry must validate against one of
    # the config classes gathered in MODEL_CONFIGS.
    models: list[MODEL_CONFIGS]
2023-03-15 05:15:53 +00:00
@models_router.get(
    "/",
    operation_id="list_models",
    responses={200: {"model": ModelsList}},
)
async def list_models(
    base_model: Optional[BaseModelType] = Query(None, description="Base model"),
    model_type: Optional[ModelType] = Query(None, description="The type of model to get"),
) -> ModelsList:
    """Gets a list of models"""
    # Both filters are optional and are forwarded positionally, None included
    # (presumably None means "do not filter" — confirm in the model manager).
    manager = ApiDependencies.invoker.services.model_manager
    raw_configs = manager.list_models(base_model, model_type)
    # Coerce the raw records into the response schema (validates each entry).
    return parse_obj_as(ModelsList, {"models": raw_configs})
2023-04-06 19:17:48 +00:00
@models_router.post(
    "/",
    operation_id="update_model",
    # Every entry of `responses` is merged into the OpenAPI Response Object
    # for that status code, which only allows fields such as `description`,
    # `headers`, `content` and `links`. The former {"status": "success"} entry
    # injected an invalid `status` field into the generated spec.
    responses={200: {"description": "The model was added successfully"}},
)
async def update_model(
    model_request: CreateModelRequest
) -> CreateModelResponse:
    """Add Model"""
    # Plain-dict form of the (already validated) info payload, as expected by
    # the model manager's `model_attributes` argument.
    info_dict = model_request.info.dict()
    # The reply simply echoes the request back with a success status.
    model_response = CreateModelResponse(
        name=model_request.name,
        info=model_request.info,
        status="success",
    )
    # clobber=True presumably lets this replace an existing model of the same
    # name — confirm against the model manager service.
    ApiDependencies.invoker.services.model_manager.add_model(
        model_name=model_request.name,
        model_attributes=info_dict,
        clobber=True,
    )
    return model_response
2023-06-23 20:35:39 +00:00
@models_router.post(
    "/import",
    operation_id="import_model",
    responses={
        201: {"description": "The model imported successfully"},
        404: {"description": "The model could not be found"},
    },
    status_code=201,
    response_model=ImportModelResponse,
)
async def import_model(
    name: str = Query(description="A model path, repo_id or URL to import"),
    prediction_type: Optional[Literal['v_prediction', 'epsilon', 'sample']] = Query(
        "v_prediction",
        description='Prediction type for SDv2 checkpoint files',
    ),
) -> ImportModelResponse:
    """Add a model using its local path, repo_id, or remote URL"""
    services = ApiDependencies.invoker.services
    logger = services.logger

    # Map the query-string value onto the enum member; anything unmatched
    # (including None) becomes None.
    members_by_value = {member.value: member for member in SchedulerPredictionType}
    chosen_prediction = members_by_value.get(prediction_type)

    installed_models = services.model_manager.heuristic_import(
        items_to_import={name},
        # The helper ignores its argument: the caller's choice applies to
        # whatever the importer is asking about.
        prediction_type_helper=lambda x: chosen_prediction,
    )

    info = installed_models.get(name)
    if not info:
        logger.error(f'Model {name} not imported')
        raise HTTPException(status_code=404, detail=f'Model {name} not found')

    logger.info(f'Successfully imported {name}, got {info}')
    return ImportModelResponse(name=name, info=info, status="success")
2023-04-06 19:17:48 +00:00
@models_router.delete(
    "/{model_name}",
    operation_id="del_model",
    # Success is the route's declared status instead of an exception.
    status_code=204,
    responses={
        204: {"description": "Model deleted successfully"},
        404: {"description": "Model not found"},
    },
)
async def delete_model(model_name: str) -> Response:
    """Delete Model"""
    model_names = ApiDependencies.invoker.services.model_manager.model_names()
    logger = ApiDependencies.invoker.services.logger

    # NOTE(review): confirm model_names() yields plain names comparable with
    # `model_name`; list_models() takes base_model/model_type filters, so the
    # manager may key models by more than the name.
    model_exists = model_name in model_names

    # check if model exists
    logger.info(f"Checking for model {model_name}...")
    if not model_exists:
        logger.error("Model not found")
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")

    logger.info(f"Deleting Model: {model_name}")
    ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
    logger.info(f"Model Deleted: {model_name}")
    # Previously this raised HTTPException(status_code=204, detail=...), using
    # the error path for success. Depending on the FastAPI version, that can
    # send a JSON body with 204 No Content, which must not have a body. An
    # explicit empty response is correct everywhere.
    return Response(status_code=204)
2023-04-06 20:23:09 +00:00
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
2023-04-19 00:49:00 +00:00
# except Exception as e: