Update to address feedback

This commit is contained in:
Kent Keirsey 2023-04-06 22:25:18 -04:00
parent 9d80b28a4f
commit 7919d81fb1

View File

@ -1,11 +1,17 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
import shutil
import asyncio
from typing import Annotated, Any, List, Literal, Optional, Union from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path from pathlib import Path
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
from invokeai.backend.globals import Globals, global_converted_ckpts_dir
from invokeai.backend.args import Args
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -34,20 +40,22 @@ class DiffusersModelInfo(ModelInfo):
repo_id: Optional[str] = Field(description="The repo ID to use for this model") repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model") path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest (BaseModel): class CreateModelRequest(BaseModel):
name: str = Field(description="The name of the model") name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info") info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
class CreateModelResponse (BaseModel): class CreateModelResponse(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info") info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
status: str = Field(description="The status of the API response") status: str = Field(description="The status of the API response")
class ConvertedModelRequest (BaseModel): class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info") info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse (BaseModel): class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
@ -70,34 +78,22 @@ async def list_models() -> ModelsList:
@models_router.post( @models_router.post(
"/", "/",
operation_id="update_model", operation_id="update_model",
responses={ responses={200: {"status": "success"}},
201: {
"model_response": "Model added",
},
202: {
"description": "Model submission is processing. Check back later."
},
},
) )
async def update_model( async def update_model(
model_request: CreateModelRequest model_request: CreateModelRequest
) -> CreateModelResponse: ) -> CreateModelResponse:
""" Add Model """ """ Add Model """
try: model_request_info = model_request.info
model_request_info = model_request.info info_dict = model_request_info.dict()
print(f">> Checking for {model_request_info}...") model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
info_dict = model_request_info.dict()
ApiDependencies.invoker.services.model_manager.add_model( ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name, model_name=model_request.name,
model_attributes=info_dict, model_attributes=info_dict,
clobber=True, clobber=True,
) )
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return model_response return model_response
@ -106,7 +102,7 @@ async def update_model(
operation_id="del_model", operation_id="del_model",
responses={ responses={
204: { 204: {
"description": "Model deleted" "description": "Model deleted successfully"
}, },
404: { 404: {
"description": "Model not found" "description": "Model not found"
@ -117,103 +113,19 @@ async def delete_model(model_name: str) -> None:
"""Delete Model""" """Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names() model_names = ApiDependencies.invoker.services.model_manager.model_names()
model_exists = model_name in model_names model_exists = model_name in model_names
try:
# check if model exists
print(f">> Checking for model {model_name}...")
if not model_exists: # check if model exists
print(f">> Model not found") print(f">> Checking for model {model_name}...")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
if model_exists:
# delete model
print(f">> Deleting Model: {model_name}") print(f">> Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True) ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}") print(f">> Model Deleted: {model_name}")
except Exception as e: raise HTTPException(status_code=200, detail=f"Model '{model_name}' deleted successfully")
raise HTTPException(status_code=500, detail=str(e))
@models_router.post( if not model_exists:
"/{model_to_convert}", print(f">> Model not found")
operation_id="convert_model", raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
responses={
201: {
"model_response": "Model converted successfully.",
},
202: {
"description": "Model conversion is processing. Check back later."
},
},
)
async def convert_model(convert_request = ConvertedModelRequest) -> ConvertedModelResponse:
""" Convert Model """
try:
convert_request_info = convert_request.info
info_dict = convert_request_info.dict()
convert_request = ConvertedModelRequest(name=convert_request.name, config=info_dict.config, weights=info_dict.weights, description=info_dict.description)
if model_info := ApiDependencies.invoker.services.model_manager.model_info(
model_name=convert_request.name
):
if "weights" in model_info:
ckpt_path = Path(convert_request.weights)
original_config_file = Path(convert_request.config)
model_name = convert_request.weights
model_description = convert_request.description
else:
raise HTTPException(status_code=404, detail=f"Model '{convert_request.name}' is not a valid checkpoint model")
else:
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info")
if not ckpt_path.is_absolute():
ckpt_path = Path(Globals.root, ckpt_path)
if original_config_file and not original_config_file.is_absolute():
original_config_file = Path(Globals.root, original_config_file)
diffusers_path = Path(
ckpt_path.parent.absolute(), f"{model_name}_diffusers"
)
if model_to_convert["save_location"] == "root":
diffusers_path = Path(
global_converted_ckpts_dir(), f"{model_name}_diffusers"
)
if (
model_to_convert["save_location"] == "custom"
and model_to_convert["custom_location"] is not None
):
diffusers_path = Path(
model_to_convert["custom_location"], f"{model_name}_diffusers"
)
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
self.generate.model_manager.convert_and_import(
ckpt_path,
diffusers_path,
model_name=model_name,
model_description=model_description,
vae=None,
original_config_file=original_config_file,
commit_to_conf=opt.conf,
)
new_model_list = self.generate.model_manager.list_models()
socketio.emit(
"modelConverted",
{
"new_model_name": model_name,
"model_list": new_model_list,
"update": True,
},
)
print(f">> Model Converted: {model_name}")
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
# @socketio.on("convertToDiffusers") # @socketio.on("convertToDiffusers")