mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
v4.2.6post
...
Convert-Mo
Author | SHA1 | Date | |
---|---|---|---|
efabf250d7 | |||
9ecca13229 |
@ -1,7 +1,7 @@
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and Kent Keirsey (https://github.com/hipsterusername)
|
||||
|
||||
import shutil
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||
|
||||
from fastapi.routing import APIRouter, HTTPException
|
||||
@ -47,9 +47,7 @@ class CreateModelResponse(BaseModel):
|
||||
|
||||
class ConversionRequest(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
info: CkptModelInfo = Field(description="The converted model info")
|
||||
save_location: str = Field(description="The path to save the converted model weights")
|
||||
|
||||
|
||||
class ConvertedModelResponse(BaseModel):
|
||||
name: str = Field(description="The name of the new model")
|
||||
@ -123,8 +121,97 @@ async def delete_model(model_name: str) -> None:
|
||||
else:
|
||||
logger.error(f"Model not found")
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||
|
||||
# TODO: Refactor these support functions below to live somewhere more appropriate
|
||||
|
||||
def get_model_info(model_name: str):
|
||||
model_info = ApiDependencies.invoker.services.model_manager.model_info(
|
||||
model_name=model_name
|
||||
)
|
||||
if not model_info:
|
||||
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info for '{model_name}'")
|
||||
return model_info
|
||||
|
||||
|
||||
def ckpt_validate(model_info: dict, model_name: str):
|
||||
if "weights" not in model_info:
|
||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not a valid checkpoint model")
|
||||
|
||||
|
||||
def get_paths(model: ConversionRequest, root: Path) -> tuple:
|
||||
model_info = get_model_info(model.name)
|
||||
ckpt_path = Path(model_info.weights)
|
||||
config_path = Path(model_info.config)
|
||||
|
||||
if not ckpt_path.is_absolute():
|
||||
ckpt_path = Path(root, ckpt_path)
|
||||
|
||||
if config_path and not config_path.is_absolute():
|
||||
config_path = Path(root, config_path)
|
||||
|
||||
return ckpt_path, config_path
|
||||
|
||||
|
||||
def get_diffusers_path(convert_request: ConversionRequest, model_name: str) -> Path:
|
||||
if convert_request.save_location == "root":
|
||||
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
|
||||
elif convert_request.save_location == "custom" and convert_request.save_location is not None:
|
||||
diffusers_path = Path(convert_request.save_location, f"{model_name}_diffusers")
|
||||
else:
|
||||
raise ValueError("Invalid save_location value")
|
||||
|
||||
if diffusers_path.exists():
|
||||
shutil.rmtree(diffusers_path)
|
||||
|
||||
return diffusers_path
|
||||
|
||||
|
||||
@models_router.post(
|
||||
"/{model_to_convert}",
|
||||
operation_id="convert_model",
|
||||
responses={
|
||||
200: {
|
||||
"model_response": "Model converted successfully.",
|
||||
}
|
||||
},
|
||||
)
|
||||
async def convert_model(convert_request: ConversionRequest) -> ConvertedModelResponse:
|
||||
"""Convert Model"""
|
||||
|
||||
opt=Args()
|
||||
args = opt.parse_args()
|
||||
|
||||
# Set the root directory for static files and relative paths
|
||||
args.root_dir = os.path.expanduser(args.root_dir or "..")
|
||||
if not os.path.isabs(args.outdir):
|
||||
args.outdir = os.path.join(args.root_dir, args.outdir)
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.isabs(opt.conf):
|
||||
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
|
||||
model_info = get_model_info(convert_request.name)
|
||||
ckpt_validate(model_info, convert_request.name)
|
||||
ckpt_path, original_config_file = get_paths(convert_request, Globals.root)
|
||||
diffusers_path = get_diffusers_path(convert_request, convert_request.name)
|
||||
|
||||
ApiDependencies.invoker.services.model_manager.convert_and_import(
|
||||
ckpt_path,
|
||||
diffusers_path,
|
||||
model_name=convert_request.name,
|
||||
model_description=model_info.description,
|
||||
vae=None,
|
||||
original_config_file=original_config_file,
|
||||
commit_to_conf=opt.conf,
|
||||
)
|
||||
|
||||
model_info = get_model_info(convert_request.name)
|
||||
convert_response = ConvertedModelResponse(name=f"{convert_request.name}_diffusers", info=model_info)
|
||||
|
||||
print(f">> Model Converted: {convert_request.name}")
|
||||
|
||||
return convert_response
|
||||
|
||||
|
||||
# @socketio.on("convertToDiffusers")
|
||||
# def convert_to_diffusers(model_to_convert: dict):
|
||||
# try:
|
||||
|
Reference in New Issue
Block a user