add API for model convert

This commit is contained in:
Lincoln Stein 2023-07-05 15:13:21 -04:00
parent 5027d0a603
commit 5b6dd47b9f
2 changed files with 36 additions and 78 deletions

View File

@ -9,9 +9,11 @@ from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
from invokeai.backend.model_management.models import (
MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
SchedulerPredictionType
)
from ..dependencies import ApiDependencies
@ -25,11 +27,6 @@ class ImportModelResponse(BaseModel):
location: str = Field(description="The path, repo_id or URL of the imported model")
info: AddModelResult = Field(description="The model info")
class ConvertModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ModelsList(BaseModel):
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
@ -168,74 +165,35 @@ async def delete_model(
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
@models_router.patch(
"/convert/{base_model}/{model_type}/{model_name}",
operation_id="convert_model",
responses={
200: { "description": "Model converted successfully" },
400: {"description" : "Bad request" },
404: { "description": "Model not found" },
},
status_code = 200,
response_model = Union[tuple(MODEL_CONFIGS)],
)
async def convert_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Union[tuple(MODEL_CONFIGS)]:
"""Convert a checkpoint model into a diffusers model"""
logger = ApiDependencies.invoker.services.logger
try:
logger.info(f"Converting model: {model_name}")
result = ApiDependencies.invoker.services.model_manager.convert_model(model_name,
base_model = base_model,
model_type = model_type
)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return result.config
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):

View File

@ -651,7 +651,7 @@ class ModelManager(object):
)
checkpoint_path = self.app_config.root_path / info["path"]
old_diffusers_path = self.app_config.models_path / model.location
new_diffusers_path = self.app_config.models_path / base_model / model_type / model_name
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
if new_diffusers_path.exists():
raise ValueError(f"A diffusers model already exists at {new_path}")