mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add API for model convert
This commit is contained in:
parent
5027d0a603
commit
5b6dd47b9f
@ -9,9 +9,11 @@ from starlette.exceptions import HTTPException
|
|||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management import AddModelResult
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
|
from invokeai.backend.model_management.models import (
|
||||||
|
MODEL_CONFIGS,
|
||||||
OPENAPI_MODEL_CONFIGS,
|
OPENAPI_MODEL_CONFIGS,
|
||||||
SchedulerPredictionType)
|
SchedulerPredictionType
|
||||||
|
)
|
||||||
|
|
||||||
from ..dependencies import ApiDependencies
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
@ -25,11 +27,6 @@ class ImportModelResponse(BaseModel):
|
|||||||
location: str = Field(description="The path, repo_id or URL of the imported model")
|
location: str = Field(description="The path, repo_id or URL of the imported model")
|
||||||
info: AddModelResult = Field(description="The model info")
|
info: AddModelResult = Field(description="The model info")
|
||||||
|
|
||||||
class ConvertModelResponse(BaseModel):
|
|
||||||
name: str = Field(description="The name of the imported model")
|
|
||||||
info: AddModelResult = Field(description="The model info")
|
|
||||||
status: str = Field(description="The status of the API response")
|
|
||||||
|
|
||||||
class ModelsList(BaseModel):
|
class ModelsList(BaseModel):
|
||||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||||
|
|
||||||
@ -168,74 +165,35 @@ async def delete_model(
|
|||||||
logger.error(f"Model not found: {model_name}")
|
logger.error(f"Model not found: {model_name}")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
|
|
||||||
# @socketio.on("convertToDiffusers")
|
@models_router.patch(
|
||||||
# def convert_to_diffusers(model_to_convert: dict):
|
"/convert/{base_model}/{model_type}/{model_name}",
|
||||||
# try:
|
operation_id="convert_model",
|
||||||
# if model_info := self.generate.model_manager.model_info(
|
responses={
|
||||||
# model_name=model_to_convert["model_name"]
|
200: { "description": "Model converted successfully" },
|
||||||
# ):
|
400: {"description" : "Bad request" },
|
||||||
# if "weights" in model_info:
|
404: { "description": "Model not found" },
|
||||||
# ckpt_path = Path(model_info["weights"])
|
},
|
||||||
# original_config_file = Path(model_info["config"])
|
status_code = 200,
|
||||||
# model_name = model_to_convert["model_name"]
|
response_model = Union[tuple(MODEL_CONFIGS)],
|
||||||
# model_description = model_info["description"]
|
)
|
||||||
# else:
|
async def convert_model(
|
||||||
# self.socketio.emit(
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
# "error", {"message": "Model is not a valid checkpoint file"}
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
# )
|
model_name: str = Path(description="model name"),
|
||||||
# else:
|
) -> Union[tuple(MODEL_CONFIGS)]:
|
||||||
# self.socketio.emit(
|
"""Convert a checkpoint model into a diffusers model"""
|
||||||
# "error", {"message": "Could not retrieve model info."}
|
logger = ApiDependencies.invoker.services.logger
|
||||||
# )
|
try:
|
||||||
|
logger.info(f"Converting model: {model_name}")
|
||||||
# if not ckpt_path.is_absolute():
|
result = ApiDependencies.invoker.services.model_manager.convert_model(model_name,
|
||||||
# ckpt_path = Path(Globals.root, ckpt_path)
|
base_model = base_model,
|
||||||
|
model_type = model_type
|
||||||
# if original_config_file and not original_config_file.is_absolute():
|
)
|
||||||
# original_config_file = Path(Globals.root, original_config_file)
|
except KeyError:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
# diffusers_path = Path(
|
except ValueError as e:
|
||||||
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
# )
|
return result.config
|
||||||
|
|
||||||
# if model_to_convert["save_location"] == "root":
|
|
||||||
# diffusers_path = Path(
|
|
||||||
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if (
|
|
||||||
# model_to_convert["save_location"] == "custom"
|
|
||||||
# and model_to_convert["custom_location"] is not None
|
|
||||||
# ):
|
|
||||||
# diffusers_path = Path(
|
|
||||||
# model_to_convert["custom_location"], f"{model_name}_diffusers"
|
|
||||||
# )
|
|
||||||
|
|
||||||
# if diffusers_path.exists():
|
|
||||||
# shutil.rmtree(diffusers_path)
|
|
||||||
|
|
||||||
# self.generate.model_manager.convert_and_import(
|
|
||||||
# ckpt_path,
|
|
||||||
# diffusers_path,
|
|
||||||
# model_name=model_name,
|
|
||||||
# model_description=model_description,
|
|
||||||
# vae=None,
|
|
||||||
# original_config_file=original_config_file,
|
|
||||||
# commit_to_conf=opt.conf,
|
|
||||||
# )
|
|
||||||
|
|
||||||
# new_model_list = self.generate.model_manager.list_models()
|
|
||||||
# socketio.emit(
|
|
||||||
# "modelConverted",
|
|
||||||
# {
|
|
||||||
# "new_model_name": model_name,
|
|
||||||
# "model_list": new_model_list,
|
|
||||||
# "update": True,
|
|
||||||
# },
|
|
||||||
# )
|
|
||||||
# print(f">> Model Converted: {model_name}")
|
|
||||||
# except Exception as e:
|
|
||||||
# self.handle_exceptions(e)
|
|
||||||
|
|
||||||
# @socketio.on("mergeDiffusersModels")
|
# @socketio.on("mergeDiffusersModels")
|
||||||
# def merge_diffusers_models(model_merge_info: dict):
|
# def merge_diffusers_models(model_merge_info: dict):
|
||||||
|
@ -651,7 +651,7 @@ class ModelManager(object):
|
|||||||
)
|
)
|
||||||
checkpoint_path = self.app_config.root_path / info["path"]
|
checkpoint_path = self.app_config.root_path / info["path"]
|
||||||
old_diffusers_path = self.app_config.models_path / model.location
|
old_diffusers_path = self.app_config.models_path / model.location
|
||||||
new_diffusers_path = self.app_config.models_path / base_model / model_type / model_name
|
new_diffusers_path = self.app_config.models_path / base_model.value / model_type.value / model_name
|
||||||
if new_diffusers_path.exists():
|
if new_diffusers_path.exists():
|
||||||
raise ValueError(f"A diffusers model already exists at {new_path}")
|
raise ValueError(f"A diffusers model already exists at {new_path}")
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user