diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index 95a0f2817c..aaeb9517d4 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -1,11 +1,17 @@ -# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) +# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) +import shutil +import asyncio from typing import Annotated, Any, List, Literal, Optional, Union from fastapi.routing import APIRouter, HTTPException from pydantic import BaseModel, Field, parse_obj_as from pathlib import Path from ..dependencies import ApiDependencies +from invokeai.backend.globals import Globals, global_converted_ckpts_dir +from invokeai.backend.args import Args + + models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -34,20 +40,22 @@ class DiffusersModelInfo(ModelInfo): repo_id: Optional[str] = Field(description="The repo ID to use for this model") path: Optional[str] = Field(description="The path to the model") -class CreateModelRequest (BaseModel): +class CreateModelRequest(BaseModel): name: str = Field(description="The name of the model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info") + info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") -class CreateModelResponse (BaseModel): +class CreateModelResponse(BaseModel): name: str = Field(description="The name of the new model") - info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info") + info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info") status: str = Field(description="The status of the API response") -class ConvertedModelRequest (BaseModel): +class ConversionRequest(BaseModel): name: str = Field(description="The name of the new model") info: CkptModelInfo = Field(description="The converted model info") + save_location: str = Field(description="The path to save the converted model weights") + -class ConvertedModelResponse (BaseModel): +class ConvertedModelResponse(BaseModel): name: str = Field(description="The name of the new model") info: DiffusersModelInfo = Field(description="The converted model info") @@ -70,34 +78,22 @@ async def list_models() -> ModelsList: @models_router.post( "/", operation_id="update_model", - responses={ - 201: { - "model_response": "Model added", - }, - 202: { - "description": "Model submission is processing. Check back later." - }, - }, + responses={200: {"status": "success"}}, ) async def update_model( model_request: CreateModelRequest ) -> CreateModelResponse: """ Add Model """ - try: - model_request_info = model_request.info - print(f">> Checking for {model_request_info}...") - info_dict = model_request_info.dict() + model_request_info = model_request.info + info_dict = model_request_info.dict() + model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") - ApiDependencies.invoker.services.model_manager.add_model( - model_name=model_request.name, - model_attributes=info_dict, - clobber=True, - ) - model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success") + ApiDependencies.invoker.services.model_manager.add_model( + model_name=model_request.name, + model_attributes=info_dict, + clobber=True, + ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - return model_response @@ -106,7 +102,7 @@ async def update_model( operation_id="del_model", responses={ 204: { - "description": "Model deleted" + "description": "Model deleted successfully" }, 404: { "description": "Model not found" @@ -117,103 +113,19 @@ async def delete_model(model_name: str) -> None: """Delete Model""" model_names = ApiDependencies.invoker.services.model_manager.model_names() model_exists = model_name in model_names - - try: - # check if model exists - print(f">> Checking for model {model_name}...") - if not model_exists: - print(f">> Model not found") - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") - - # delete model + # check if model exists + print(f">> Checking for model {model_name}...") + + if model_exists: print(f">> Deleting Model: {model_name}") ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) print(f">> Model Deleted: {model_name}") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=200, detail=f"Model '{model_name}' deleted successfully") -@models_router.post( - "/{model_to_convert}", - operation_id="convert_model", - responses={ - 201: { - "model_response": "Model converted successfully.", - }, - 202: { - "description": "Model conversion is processing. Check back later." - }, - }, -) -async def convert_model(convert_request = ConvertedModelRequest) -> ConvertedModelResponse: - """ Convert Model """ - try: - convert_request_info = convert_request.info - info_dict = convert_request_info.dict() - convert_request = ConvertedModelRequest(name=convert_request.name, config=info_dict.config, weights=info_dict.weights, description=info_dict.description) - - if model_info := ApiDependencies.invoker.services.model_manager.model_info( - model_name=convert_request.name - ): - if "weights" in model_info: - ckpt_path = Path(convert_request.weights) - original_config_file = Path(convert_request.config) - model_name = convert_request.weights - model_description = convert_request.description - else: - raise HTTPException(status_code=404, detail=f"Model '{convert_request.name}' is not a valid checkpoint model") - else: - raise HTTPException(status_code=404, detail=f"Unable to retrieve model info") - - if not ckpt_path.is_absolute(): - ckpt_path = Path(Globals.root, ckpt_path) - - if original_config_file and not original_config_file.is_absolute(): - original_config_file = Path(Globals.root, original_config_file) - - diffusers_path = Path( - ckpt_path.parent.absolute(), f"{model_name}_diffusers" - ) - - if model_to_convert["save_location"] == "root": - diffusers_path = Path( - global_converted_ckpts_dir(), f"{model_name}_diffusers" - ) - - if ( - model_to_convert["save_location"] == "custom" - and model_to_convert["custom_location"] is not None - ): - diffusers_path = Path( - model_to_convert["custom_location"], f"{model_name}_diffusers" - ) - - if diffusers_path.exists(): - shutil.rmtree(diffusers_path) - - self.generate.model_manager.convert_and_import( - ckpt_path, - diffusers_path, - model_name=model_name, - model_description=model_description, - vae=None, - original_config_file=original_config_file, - commit_to_conf=opt.conf, - ) - - new_model_list = self.generate.model_manager.list_models() - socketio.emit( - "modelConverted", - { - "new_model_name": model_name, - "model_list": new_model_list, - "update": True, - }, - ) - print(f">> Model Converted: {model_name}") - - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + if not model_exists: + print(f">> Model not found") + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") # @socketio.on("convertToDiffusers")