mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Compare commits
2 Commits
ebr/fix-in
...
Convert-Mo
Author | SHA1 | Date | |
---|---|---|---|
efabf250d7 | |||
9ecca13229 |
@ -1,7 +1,7 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and Kent Keirsey (https://github.com/hipsterusername)
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
import asyncio
|
import os
|
||||||
from typing import Annotated, Any, List, Literal, Optional, Union
|
from typing import Annotated, Any, List, Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter, HTTPException
|
||||||
@ -47,10 +47,8 @@ class CreateModelResponse(BaseModel):
|
|||||||
|
|
||||||
class ConversionRequest(BaseModel):
|
class ConversionRequest(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
info: CkptModelInfo = Field(description="The converted model info")
|
|
||||||
save_location: str = Field(description="The path to save the converted model weights")
|
save_location: str = Field(description="The path to save the converted model weights")
|
||||||
|
|
||||||
|
|
||||||
class ConvertedModelResponse(BaseModel):
|
class ConvertedModelResponse(BaseModel):
|
||||||
name: str = Field(description="The name of the new model")
|
name: str = Field(description="The name of the new model")
|
||||||
info: DiffusersModelInfo = Field(description="The converted model info")
|
info: DiffusersModelInfo = Field(description="The converted model info")
|
||||||
@ -124,6 +122,95 @@ async def delete_model(model_name: str) -> None:
|
|||||||
logger.error(f"Model not found")
|
logger.error(f"Model not found")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|
||||||
|
# TODO: Refactor these support functions below to live somewhere more appropriate
|
||||||
|
|
||||||
|
def get_model_info(model_name: str):
|
||||||
|
model_info = ApiDependencies.invoker.services.model_manager.model_info(
|
||||||
|
model_name=model_name
|
||||||
|
)
|
||||||
|
if not model_info:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Unable to retrieve model info for '{model_name}'")
|
||||||
|
return model_info
|
||||||
|
|
||||||
|
|
||||||
|
def ckpt_validate(model_info: dict, model_name: str):
|
||||||
|
if "weights" not in model_info:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' is not a valid checkpoint model")
|
||||||
|
|
||||||
|
|
||||||
|
def get_paths(model: ConversionRequest, root: Path) -> tuple:
|
||||||
|
model_info = get_model_info(model.name)
|
||||||
|
ckpt_path = Path(model_info.weights)
|
||||||
|
config_path = Path(model_info.config)
|
||||||
|
|
||||||
|
if not ckpt_path.is_absolute():
|
||||||
|
ckpt_path = Path(root, ckpt_path)
|
||||||
|
|
||||||
|
if config_path and not config_path.is_absolute():
|
||||||
|
config_path = Path(root, config_path)
|
||||||
|
|
||||||
|
return ckpt_path, config_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_diffusers_path(convert_request: ConversionRequest, model_name: str) -> Path:
|
||||||
|
if convert_request.save_location == "root":
|
||||||
|
diffusers_path = Path(global_converted_ckpts_dir(), f"{model_name}_diffusers")
|
||||||
|
elif convert_request.save_location == "custom" and convert_request.save_location is not None:
|
||||||
|
diffusers_path = Path(convert_request.save_location, f"{model_name}_diffusers")
|
||||||
|
else:
|
||||||
|
raise ValueError("Invalid save_location value")
|
||||||
|
|
||||||
|
if diffusers_path.exists():
|
||||||
|
shutil.rmtree(diffusers_path)
|
||||||
|
|
||||||
|
return diffusers_path
|
||||||
|
|
||||||
|
|
||||||
|
@models_router.post(
|
||||||
|
"/{model_to_convert}",
|
||||||
|
operation_id="convert_model",
|
||||||
|
responses={
|
||||||
|
200: {
|
||||||
|
"model_response": "Model converted successfully.",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def convert_model(convert_request: ConversionRequest) -> ConvertedModelResponse:
|
||||||
|
"""Convert Model"""
|
||||||
|
|
||||||
|
opt=Args()
|
||||||
|
args = opt.parse_args()
|
||||||
|
|
||||||
|
# Set the root directory for static files and relative paths
|
||||||
|
args.root_dir = os.path.expanduser(args.root_dir or "..")
|
||||||
|
if not os.path.isabs(args.outdir):
|
||||||
|
args.outdir = os.path.join(args.root_dir, args.outdir)
|
||||||
|
|
||||||
|
# normalize the config directory relative to root
|
||||||
|
if not os.path.isabs(opt.conf):
|
||||||
|
opt.conf = os.path.normpath(os.path.join(Globals.root, opt.conf))
|
||||||
|
model_info = get_model_info(convert_request.name)
|
||||||
|
ckpt_validate(model_info, convert_request.name)
|
||||||
|
ckpt_path, original_config_file = get_paths(convert_request, Globals.root)
|
||||||
|
diffusers_path = get_diffusers_path(convert_request, convert_request.name)
|
||||||
|
|
||||||
|
ApiDependencies.invoker.services.model_manager.convert_and_import(
|
||||||
|
ckpt_path,
|
||||||
|
diffusers_path,
|
||||||
|
model_name=convert_request.name,
|
||||||
|
model_description=model_info.description,
|
||||||
|
vae=None,
|
||||||
|
original_config_file=original_config_file,
|
||||||
|
commit_to_conf=opt.conf,
|
||||||
|
)
|
||||||
|
|
||||||
|
model_info = get_model_info(convert_request.name)
|
||||||
|
convert_response = ConvertedModelResponse(name=f"{convert_request.name}_diffusers", info=model_info)
|
||||||
|
|
||||||
|
print(f">> Model Converted: {convert_request.name}")
|
||||||
|
|
||||||
|
return convert_response
|
||||||
|
|
||||||
|
|
||||||
# @socketio.on("convertToDiffusers")
|
# @socketio.on("convertToDiffusers")
|
||||||
# def convert_to_diffusers(model_to_convert: dict):
|
# def convert_to_diffusers(model_to_convert: dict):
|
||||||
|
Reference in New Issue
Block a user