From 5b6dd47b9f0ecb2a6cf7097bf8903cb4ee5358ef Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 5 Jul 2023 15:13:21 -0400 Subject: [PATCH] add API for model convert --- invokeai/app/api/routers/models.py | 112 ++++++------------ .../backend/model_management/model_manager.py | 2 +- 2 files changed, 36 insertions(+), 78 deletions(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index a4360a3285..a3c0d1db50 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -9,9 +9,11 @@ from starlette.exceptions import HTTPException from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_management import AddModelResult -from invokeai.backend.model_management.models import (MODEL_CONFIGS, - OPENAPI_MODEL_CONFIGS, - SchedulerPredictionType) +from invokeai.backend.model_management.models import ( + MODEL_CONFIGS, + OPENAPI_MODEL_CONFIGS, + SchedulerPredictionType +) from ..dependencies import ApiDependencies @@ -25,11 +27,6 @@ class ImportModelResponse(BaseModel): location: str = Field(description="The path, repo_id or URL of the imported model") info: AddModelResult = Field(description="The model info") -class ConvertModelResponse(BaseModel): - name: str = Field(description="The name of the imported model") - info: AddModelResult = Field(description="The model info") - status: str = Field(description="The status of the API response") - class ModelsList(BaseModel): models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]] @@ -168,75 +165,36 @@ async def delete_model( logger.error(f"Model not found: {model_name}") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - # @socketio.on("convertToDiffusers") - # def convert_to_diffusers(model_to_convert: dict): - # try: - # if model_info := self.generate.model_manager.model_info( - # model_name=model_to_convert["model_name"] - # ): - # if "weights" in model_info: - # ckpt_path = Path(model_info["weights"]) - # original_config_file = Path(model_info["config"]) - # model_name = model_to_convert["model_name"] - # model_description = model_info["description"] - # else: - # self.socketio.emit( - # "error", {"message": "Model is not a valid checkpoint file"} - # ) - # else: - # self.socketio.emit( - # "error", {"message": "Could not retrieve model info."} - # ) - - # if not ckpt_path.is_absolute(): - # ckpt_path = Path(Globals.root, ckpt_path) - - # if original_config_file and not original_config_file.is_absolute(): - # original_config_file = Path(Globals.root, original_config_file) - - # diffusers_path = Path( - # ckpt_path.parent.absolute(), f"{model_name}_diffusers" - # ) - - # if model_to_convert["save_location"] == "root": - # diffusers_path = Path( - # global_converted_ckpts_dir(), f"{model_name}_diffusers" - # ) - - # if ( - # model_to_convert["save_location"] == "custom" - # and model_to_convert["custom_location"] is not None - # ): - # diffusers_path = Path( - # model_to_convert["custom_location"], f"{model_name}_diffusers" - # ) - - # if diffusers_path.exists(): - # shutil.rmtree(diffusers_path) - - # self.generate.model_manager.convert_and_import( - # ckpt_path, - # diffusers_path, - # model_name=model_name, - # model_description=model_description, - # vae=None, - # original_config_file=original_config_file, - # commit_to_conf=opt.conf, - # ) - - # new_model_list = self.generate.model_manager.list_models() - # socketio.emit( - # "modelConverted", - # { - # "new_model_name": model_name, - # "model_list": new_model_list, - # "update": True, - # }, - # ) - # print(f">> Model Converted: {model_name}") - # except Exception as e: - # self.handle_exceptions(e) - +@models_router.patch( + "/convert/{base_model}/{model_type}/{model_name}", + operation_id="convert_model", + responses={ + 200: { "description": "Model converted successfully" }, + 400: {"description" : "Bad request" }, + 404: { "description": "Model not found" }, + }, + status_code = 200, + response_model = Union[tuple(MODEL_CONFIGS)], +) +async def convert_model( + base_model: BaseModelType = Path(description="Base model"), + model_type: ModelType = Path(description="The type of model"), + model_name: str = Path(description="model name"), +) -> Union[tuple(MODEL_CONFIGS)]: + """Convert a checkpoint model into a diffusers model""" + logger = ApiDependencies.invoker.services.logger + try: + logger.info(f"Converting model: {model_name}") + result = ApiDependencies.invoker.services.model_manager.convert_model(model_name, + base_model = base_model, + model_type = model_type + ) + except KeyError: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + return result.config + # @socketio.on("mergeDiffusersModels") # def merge_diffusers_models(model_merge_info: dict): # try: diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 28183a12e9..e8a4a0541c 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -651,7 +651,7 @@ class ModelManager(object): ) checkpoint_path = self.app_config.root_path / info["path"] old_diffusers_path = self.app_config.models_path / model.location - new_diffusers_path = self.app_config.models_path / base_model / model_type / model_name + new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name if new_diffusers_path.exists(): raise ValueError(f"A diffusers model already exists at {new_path}")