Compare commits

...

2 Commits

Author SHA1 Message Date
efabf250d7 Merge branch 'main' into Convert-Model-Endpoint 2023-05-18 18:51:38 -04:00
9ecca13229 Add Convert Model Endpoint 2023-04-08 18:05:21 -04:00

View File

@ -1,7 +1,7 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername) # Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and Kent Keirsey (https://github.com/hipsterusername)
import shutil import shutil
import asyncio import os
from typing import Annotated, Any, List, Literal, Optional, Union from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter, HTTPException
@ -47,10 +47,8 @@ class CreateModelResponse(BaseModel):
class ConversionRequest(BaseModel): class ConversionRequest(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
save_location: str = Field(description="The path to save the converted model weights") save_location: str = Field(description="The path to save the converted model weights")
class ConvertedModelResponse(BaseModel): class ConvertedModelResponse(BaseModel):
name: str = Field(description="The name of the new model") name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info") info: DiffusersModelInfo = Field(description="The converted model info")
@ -124,6 +122,95 @@ async def delete_model(model_name: str) -> None:
logger.error(f"Model not found") logger.error(f"Model not found")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# TODO: Refactor these support functions below to live somewhere more appropriate
def get_model_info(model_name: str):
model_info = ApiDependencies.invoker.services.model_manager.model_info(
model_name=model_name
)
if not model_info:
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info for '{model_name}'")
return model_info
def ckpt_validate(model_info: dict, model_name: str):
if "weights" not in model_info:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not a valid checkpoint model")
def get_paths(model: ConversionRequest, root: Path) -> tuple:
model_info = get_model_info(model.name)
ckpt_path = Path(model_info.weights)
config_path = Path(model_info.config)
if not ckpt_path.is_absolute():
ckpt_path = Path(root, ckpt_path)
if config_path and not config_path.is_absolute():
config_path = Path(root, config_path)
return ckpt_path, config_path
def get_diffusers_path(convert_request: ConversionRequest, model_name: str) -> Path:
if convert_request.save_location == "root":
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
elif convert_request.save_location == "custom" and convert_request.save_location is not None:
diffusers_path = Path(convert_request.save_location, f"{model_name}_diffusers")
else:
raise ValueError("Invalid save_location value")
if diffusers_path.exists():
shutil.rmtree(diffusers_path)
return diffusers_path
@models_router.post(
"/{model_to_convert}",
operation_id="convert_model",
responses={
200: {
"model_response": "Model converted successfully.",
}
},
)
async def convert_model(convert_request: ConversionRequest) -> ConvertedModelResponse:
"""Convert Model"""
opt=Args()
args = opt.parse_args()
# Set the root directory for static files and relative paths
args.root_dir = os.path.expanduser(args.root_dir or "..")
if not os.path.isabs(args.outdir):
args.outdir = os.path.join(args.root_dir, args.outdir)
# normalize the config directory relative to root
if not os.path.isabs(opt.conf):
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
model_info = get_model_info(convert_request.name)
ckpt_validate(model_info, convert_request.name)
ckpt_path, original_config_file = get_paths(convert_request, Globals.root)
diffusers_path = get_diffusers_path(convert_request, convert_request.name)
ApiDependencies.invoker.services.model_manager.convert_and_import(
ckpt_path,
diffusers_path,
model_name=convert_request.name,
model_description=model_info.description,
vae=None,
original_config_file=original_config_file,
commit_to_conf=opt.conf,
)
model_info = get_model_info(convert_request.name)
convert_response = ConvertedModelResponse(name=f"{convert_request.name}_diffusers", info=model_info)
print(f">> Model Converted: {convert_request.name}")
return convert_response
# @socketio.on("convertToDiffusers") # @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict): # def convert_to_diffusers(model_to_convert: dict):