InvokeAI/invokeai/app/api/routers/models.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

263 lines
11 KiB
Python
Raw Normal View History

2023-04-07 02:25:18 +00:00
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654) and 2023 Kent Keirsey (https://github.com/hipsterusername)
2023-06-11 03:12:21 +00:00
from typing import Annotated, Literal, Optional, Union, Dict
from fastapi import Query
2023-04-06 19:17:48 +00:00
from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
2023-06-11 03:12:21 +00:00
from invokeai.backend import BaseModelType, ModelType
2023-06-17 14:15:36 +00:00
from invokeai.backend.model_management.models import OPENAPI_MODEL_CONFIGS
MODEL_CONFIGS = Union[tuple(OPENAPI_MODEL_CONFIGS)]
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class VaeRepo(BaseModel):
repo_id: str = Field(description="The repo ID to use for this VAE")
path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model")
model_name: str = Field(description="The name of the model")
model_type: str = Field(description="The type of the model")
class DiffusersModelInfo(ModelInfo):
format: Literal['folder'] = 'folder'
vae: Optional[VaeRepo] = Field(description="The VAE repo to use for this model")
repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model")
class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt'
config: str = Field(description="The path to the model config")
weights: str = Field(description="The path to the model weights")
vae: str = Field(description="The path to the model VAE")
width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model")
class SafetensorsModelInfo(CkptModelInfo):
format: Literal['safetensors'] = 'safetensors'
2023-04-07 02:25:18 +00:00
class CreateModelRequest(BaseModel):
2023-04-06 19:17:48 +00:00
name: str = Field(description="The name of the model")
2023-04-07 02:25:18 +00:00
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
2023-04-06 19:17:48 +00:00
2023-04-07 02:25:18 +00:00
class CreateModelResponse(BaseModel):
2023-04-06 19:17:48 +00:00
name: str = Field(description="The name of the new model")
2023-04-07 02:25:18 +00:00
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(discriminator="format", description="The model info")
2023-04-06 19:17:48 +00:00
status: str = Field(description="The status of the API response")
2023-04-07 02:25:18 +00:00
class ConversionRequest(BaseModel):
2023-04-06 19:17:48 +00:00
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
2023-04-07 02:25:18 +00:00
save_location: str = Field(description="The path to save the converted model weights")
2023-04-06 19:17:48 +00:00
2023-04-07 02:25:18 +00:00
class ConvertedModelResponse(BaseModel):
2023-04-06 19:17:48 +00:00
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel):
models: Dict[BaseModelType, Dict[ModelType, Dict[str, MODEL_CONFIGS]]] # TODO: debug/discuss with frontend
2023-06-11 03:12:21 +00:00
#models: dict[SDModelType, dict[str, Annotated[Union[(DiffusersModelInfo,CkptModelInfo,SafetensorsModelInfo)], Field(discriminator="format")]]]
@models_router.get(
"/",
operation_id="list_models",
responses={200: {"model": ModelsList }},
)
async def list_models(
2023-06-11 03:12:21 +00:00
base_model: BaseModelType = Query(
default=None, description="Base model"
),
model_type: ModelType = Query(
default=None, description="The type of model to get"
),
) -> ModelsList:
"""Gets a list of models"""
models_raw = ApiDependencies.invoker.services.model_manager.list_models(base_model, model_type)
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
2023-04-06 19:17:48 +00:00
@models_router.post(
"/",
operation_id="update_model",
2023-04-07 02:25:18 +00:00
responses={200: {"status": "success"}},
2023-04-06 19:17:48 +00:00
)
async def update_model(
model_request: CreateModelRequest
) -> CreateModelResponse:
""" Add Model """
2023-04-07 02:25:18 +00:00
model_request_info = model_request.info
info_dict = model_request_info.dict()
model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_request.name,
model_attributes=info_dict,
clobber=True,
)
2023-04-06 19:17:48 +00:00
return model_response
@models_router.delete(
"/{model_name}",
operation_id="del_model",
responses={
204: {
2023-04-07 02:25:18 +00:00
"description": "Model deleted successfully"
2023-04-06 19:17:48 +00:00
},
404: {
"description": "Model not found"
}
},
)
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
2023-04-29 14:48:50 +00:00
logger = ApiDependencies.invoker.services.logger
2023-04-06 19:17:48 +00:00
model_exists = model_name in model_names
2023-04-07 02:25:18 +00:00
# check if model exists
2023-04-29 13:43:40 +00:00
logger.info(f"Checking for model {model_name}...")
2023-04-07 02:25:18 +00:00
if model_exists:
2023-04-29 13:43:40 +00:00
logger.info(f"Deleting Model: {model_name}")
2023-04-06 19:17:48 +00:00
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
2023-04-29 13:43:40 +00:00
logger.info(f"Model Deleted: {model_name}")
2023-04-07 02:26:28 +00:00
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
2023-04-06 19:17:48 +00:00
2023-04-08 02:25:30 +00:00
else:
logger.error("Model not found")
2023-04-07 02:25:18 +00:00
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
2023-04-06 20:23:09 +00:00
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e: