diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index c988d89596..77b9b947f8 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -17,7 +17,6 @@ class VaeRepo(BaseModel): class ModelInfo(BaseModel): - model_name: str = Field(..., description="The name of the model") description: Optional[str] = Field(description="A description of the model") @@ -38,12 +37,13 @@ class DiffusersModelInfo(ModelInfo): repo_id: Optional[str] = Field(description="The repo ID to use for this model") path: Optional[str] = Field(description="The path to the model") +class CreateModelRequest (BaseModel): + name: str = Field(description="The name of the model") + info: Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")] = Field(description="The model info") class ModelsList(BaseModel): models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] - - @models_router.get( "/", operation_id="list_models", @@ -55,30 +55,28 @@ async def list_models() -> ModelsList: models = parse_obj_as(ModelsList, { "models": models_raw }) return models -# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed. -# Seriously, check this. I have no idea how to test it. - -""" Add Model """ +#Update Model @models_router.post( "/", - operation_id="add_model", - responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}}, + operation_id="update_model", + responses={ + 201: { + "model": Union[CkptModelInfo, DiffusersModelInfo], + "new_model_list": ModelsList + }, + 202: { + "description": "Model submission is processing. Check back later." + }, + }, ) -async def add_model( - model_info: Union[CkptModelInfo, DiffusersModelInfo], -) -> Union[CkptModelInfo, DiffusersModelInfo]: - """Adds a new model""" +async def update_model( + model_request: CreateModelRequest +) -> CreateModelRequest: + #Adds a new model try: - model_name = model_info["model_name"] - del model_info["model_name"] - model_attributes = model_info - - if len(model_attributes.get("vae", [])) == 0: - del model_attributes["vae"] - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_name, - model_attributes=model_attributes, + model_name=model_request["name"], + model_attributes=model_request["info"], clobber=True, ) # How does Ckpt support deprecation change the above? @@ -88,7 +86,7 @@ async def add_model( # or raise the exception to be handled by the global exception handler raise HTTPException(status_code=500, detail=str(e)) - return model_info + return model_request """ Delete Model """ @models_router.delete( @@ -111,34 +109,6 @@ async def delete_model(model_name: str) -> None: raise HTTPException(status_code=500, detail=str(e)) -""" Load Model """ -models_router.post( - "/load/{model_name}", - operation_id="load_model", - responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}}, -) -async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]: - """ - Load an existing model by name - """ - try: - # check if model exists - if model_name not in ApiDependencies.invoker.services.model_manager.models: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - - # load model - model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name) - print(f">> Model Loaded: {model_name}") - return model_info - - except Exception as e: - # Handle any exceptions thrown during the execution of the method - raise HTTPException(status_code=500, detail=str(e)) - - - - - # @socketio.on("requestSystemConfig") # def handle_request_capabilities(): # print(">> System config requested")