diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index f577e422d4..e6083a5bcb 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -16,7 +16,6 @@ class VaeRepo(BaseModel): subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") class ModelInfo(BaseModel): - model_name: str = Field(..., description="The name of the model") description: Optional[str] = Field(description="A description of the model") class CkptModelInfo(ModelInfo): @@ -58,6 +57,90 @@ async def list_models() -> ModelsList: models = parse_obj_as(ModelsList, { "models": models_raw }) return models +# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed. +# Seriously, check this. I have no idea how to test it. + +""" Add Model """ +@models_router.post( + "/", + operation_id="add_model", + responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}}, +) +async def add_model( + model_info: Union[CkptModelInfo, DiffusersModelInfo], +) -> Union[CkptModelInfo, DiffusersModelInfo]: + """Adds a new model""" + try: + model_name = model_info["model_name"] + del model_info["model_name"] + model_attributes = model_info + + if len(model_attributes.get("vae", [])) == 0: + del model_attributes["vae"] + + ApiDependencies.invoker.services.model_manager.add_model( + model_name=model_name, + model_attributes=model_attributes, + clobber=True, + ) + # How does Ckpt support deprecation change the above? + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + # or raise the exception to be handled by the global exception handler + raise HTTPException(status_code=500, detail=str(e)) + + return model_info + +""" Delete Model """ +@models_router.delete( + "/{model_name}", + operation_id="del_model", + responses={204: {"description": "Model deleted"}, 404: {"description": "Model not found"}}, +) +async def delete_model(model_name: str) -> None: + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # delete model + print(f">> Deleting Model: {model_name}") + ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) + print(f">> Model Deleted: {model_name}") + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + +""" Load Model """ +models_router.post( + "/load/{model_name}", + operation_id="load_model", + responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}}, +) +async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]: + """ + Load an existing model by name + """ + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # load model + model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name) + print(f">> Model Loaded: {model_name}") + return model_info + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + + + + # @socketio.on("requestSystemConfig") # def handle_request_capabilities(): # print(">> System config requested")