diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 07ecb45003..95a0f2817c 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -133,3 +133,207 @@ async def delete_model(model_name: str) -> None: except Exception as e: raise HTTPException(status_code=500, detail=str(e)) +@models_router.post( + "/{model_to_convert}", + operation_id="convert_model", + responses={ + 201: { + "model_response": "Model converted successfully.", + }, + 202: { + "description": "Model conversion is processing. Check back later." + }, + }, +) +async def convert_model(convert_request = ConvertedModelRequest) -> ConvertedModelResponse: + """ Convert Model """ + try: + convert_request_info = convert_request.info + info_dict = convert_request_info.dict() + convert_request = ConvertedModelRequest(name=convert_request.name, config=info_dict.config, weights=info_dict.weights, description=info_dict.description) + + if model_info := ApiDependencies.invoker.services.model_manager.model_info( + model_name=convert_request.name + ): + if "weights" in model_info: + ckpt_path = Path(convert_request.weights) + original_config_file = Path(convert_request.config) + model_name = convert_request.weights + model_description = convert_request.description + else: + raise HTTPException(status_code=404, detail=f"Model '{convert_request.name}' is not a valid checkpoint model") + else: + raise HTTPException(status_code=404, detail=f"Unable to retrieve model info") + + if not ckpt_path.is_absolute(): + ckpt_path = Path(Globals.root, ckpt_path) + + if original_config_file and not original_config_file.is_absolute(): + original_config_file = Path(Globals.root, original_config_file) + + diffusers_path = Path( + ckpt_path.parent.absolute(), f"{model_name}_diffusers" + ) + + if model_to_convert["save_location"] == "root": + diffusers_path = Path( + global_converted_ckpts_dir(), f"{model_name}_diffusers" + ) + + if ( + model_to_convert["save_location"] == "custom" + and model_to_convert["custom_location"] is not None + ): + diffusers_path = Path( + model_to_convert["custom_location"], f"{model_name}_diffusers" + ) + + if diffusers_path.exists(): + shutil.rmtree(diffusers_path) + + self.generate.model_manager.convert_and_import( + ckpt_path, + diffusers_path, + model_name=model_name, + model_description=model_description, + vae=None, + original_config_file=original_config_file, + commit_to_conf=opt.conf, + ) + + new_model_list = self.generate.model_manager.list_models() + socketio.emit( + "modelConverted", + { + "new_model_name": model_name, + "model_list": new_model_list, + "update": True, + }, + ) + print(f">> Model Converted: {model_name}") + + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + + # @socketio.on("convertToDiffusers") + # def convert_to_diffusers(model_to_convert: dict): + # try: + # if model_info := self.generate.model_manager.model_info( + # model_name=model_to_convert["model_name"] + # ): + # if "weights" in model_info: + # ckpt_path = Path(model_info["weights"]) + # original_config_file = Path(model_info["config"]) + # model_name = model_to_convert["model_name"] + # model_description = model_info["description"] + # else: + # self.socketio.emit( + # "error", {"message": "Model is not a valid checkpoint file"} + # ) + # else: + # self.socketio.emit( + # "error", {"message": "Could not retrieve model info."} + # ) + + # if not ckpt_path.is_absolute(): + # ckpt_path = Path(Globals.root, ckpt_path) + + # if original_config_file and not original_config_file.is_absolute(): + # original_config_file = Path(Globals.root, original_config_file) + + # diffusers_path = Path( + # ckpt_path.parent.absolute(), f"{model_name}_diffusers" + # ) + + # if model_to_convert["save_location"] == "root": + # diffusers_path = Path( + # global_converted_ckpts_dir(), f"{model_name}_diffusers" + # ) + + # if ( + # model_to_convert["save_location"] == "custom" + # and model_to_convert["custom_location"] is not None + # ): + # diffusers_path = Path( + # model_to_convert["custom_location"], f"{model_name}_diffusers" + # ) + + # if diffusers_path.exists(): + # shutil.rmtree(diffusers_path) + + # self.generate.model_manager.convert_and_import( + # ckpt_path, + # diffusers_path, + # model_name=model_name, + # model_description=model_description, + # vae=None, + # original_config_file=original_config_file, + # commit_to_conf=opt.conf, + # ) + + # new_model_list = self.generate.model_manager.list_models() + # socketio.emit( + # "modelConverted", + # { + # "new_model_name": model_name, + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Model Converted: {model_name}") + # except Exception as e: + # self.handle_exceptions(e) + + # @socketio.on("mergeDiffusersModels") + # def merge_diffusers_models(model_merge_info: dict): + # try: + # models_to_merge = model_merge_info["models_to_merge"] + # model_ids_or_paths = [ + # self.generate.model_manager.model_name_or_path(x) + # for x in models_to_merge + # ] + # merged_pipe = merge_diffusion_models( + # model_ids_or_paths, + # model_merge_info["alpha"], + # model_merge_info["interp"], + # model_merge_info["force"], + # ) + + # dump_path = global_models_dir() / "merged_models" + # if model_merge_info["model_merge_save_path"] is not None: + # dump_path = Path(model_merge_info["model_merge_save_path"]) + + # os.makedirs(dump_path, exist_ok=True) + # dump_path = dump_path / model_merge_info["merged_model_name"] + # merged_pipe.save_pretrained(dump_path, safe_serialization=1) + + # merged_model_config = dict( + # model_name=model_merge_info["merged_model_name"], + # description=f'Merge of models {", ".join(models_to_merge)}', + # commit_to_conf=opt.conf, + # ) + + # if vae := self.generate.model_manager.config[models_to_merge[0]].get( + # "vae", None + # ): + # print(f">> Using configured VAE assigned to {models_to_merge[0]}") + # merged_model_config.update(vae=vae) + + # self.generate.model_manager.import_diffuser_model( + # dump_path, **merged_model_config + # ) + # new_model_list = self.generate.model_manager.list_models() + + # socketio.emit( + # "modelsMerged", + # { + # "merged_models": models_to_merge, + # "merged_model_name": model_merge_info["merged_model_name"], + # "model_list": new_model_list, + # "update": True, + # }, + # ) + # print(f">> Models Merged: {models_to_merge}") + # print(f">> New Model Added: {model_merge_info['merged_model_name']}") + # except Exception as e: \ No newline at end of file