From 9e3cd33a9959332360278c6806f3f3c49d5d0d70 Mon Sep 17 00:00:00 2001 From: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Date: Sat, 25 Mar 2023 21:51:38 -0400 Subject: [PATCH] Adds (Untested) Add Delete Load Endpoints --- invokeai/app/api/routers/models.py | 87 +++++++++++++++++++++++++++++- 1 file changed, 86 insertions(+), 1 deletion(-) diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 5b3fbebddd..c988d89596 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -2,7 +2,7 @@ from typing import Annotated, Any, List, Literal, Optional, Union -from fastapi.routing import APIRouter +from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from ..dependencies import ApiDependencies @@ -17,6 +17,7 @@ class VaeRepo(BaseModel): class ModelInfo(BaseModel): + model_name: str = Field(..., description="The name of the model") description: Optional[str] = Field(description="A description of the model") @@ -54,6 +55,90 @@ async def list_models() -> ModelsList: models = parse_obj_as(ModelsList, { "models": models_raw }) return models +# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed. +# Seriously, check this. I have no idea how to test it. + +""" Add Model """ +@models_router.post( + "/", + operation_id="add_model", + responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}}, +) +async def add_model( + model_info: Union[CkptModelInfo, DiffusersModelInfo], +) -> Union[CkptModelInfo, DiffusersModelInfo]: + """Adds a new model""" + try: + model_name = model_info["model_name"] + del model_info["model_name"] + model_attributes = model_info + + if len(model_attributes.get("vae", [])) == 0: + del model_attributes["vae"] + + ApiDependencies.invoker.services.model_manager.add_model( + model_name=model_name, + model_attributes=model_attributes, + clobber=True, + ) + # How does Ckpt support deprecation change the above? + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + # or raise the exception to be handled by the global exception handler + raise HTTPException(status_code=500, detail=str(e)) + + return model_info + +""" Delete Model """ +@models_router.delete( + "/{model_name}", + operation_id="del_model", + responses={204: {"description": "Model deleted"}, 404: {"description": "Model not found"}}, +) +async def delete_model(model_name: str) -> None: + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # delete model + print(f">> Deleting Model: {model_name}") + ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) + print(f">> Model Deleted: {model_name}") + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + +""" Load Model """ +models_router.post( + "/load/{model_name}", + operation_id="load_model", + responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}}, +) +async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]: + """ + Load an existing model by name + """ + try: + # check if model exists + if model_name not in ApiDependencies.invoker.services.model_manager.models: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + + # load model + model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name) + print(f">> Model Loaded: {model_name}") + return model_info + + except Exception as e: + # Handle any exceptions thrown during the execution of the method + raise HTTPException(status_code=500, detail=str(e)) + + + + + # @socketio.on("requestSystemConfig") # def handle_request_capabilities(): # print(">> System config requested")