Add/Update and Delete Models

This commit is contained in:
Kent Keirsey 2023-04-06 15:17:48 -04:00
parent e456e2e63a
commit 1fcd91bcc5

View File

@ -2,9 +2,9 @@
from typing import Annotated, Any, List, Literal, Optional, Union from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from pathlib import Path
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -15,11 +15,9 @@ class VaeRepo(BaseModel):
path: Optional[str] = Field(description="The path to the VAE") path: Optional[str] = Field(description="The path to the VAE")
subfolder: Optional[str] = Field(description="The subfolder to use for this VAE") subfolder: Optional[str] = Field(description="The subfolder to use for this VAE")
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
description: Optional[str] = Field(description="A description of the model") description: Optional[str] = Field(description="A description of the model")
class CkptModelInfo(ModelInfo): class CkptModelInfo(ModelInfo):
format: Literal['ckpt'] = 'ckpt' format: Literal['ckpt'] = 'ckpt'
@ -29,7 +27,6 @@ class CkptModelInfo(ModelInfo):
width: Optional[int] = Field(description="The width of the model") width: Optional[int] = Field(description="The width of the model")
height: Optional[int] = Field(description="The height of the model") height: Optional[int] = Field(description="The height of the model")
class DiffusersModelInfo(ModelInfo): class DiffusersModelInfo(ModelInfo):
format: Literal['diffusers'] = 'diffusers' format: Literal['diffusers'] = 'diffusers'
@ -37,12 +34,27 @@ class DiffusersModelInfo(ModelInfo):
repo_id: Optional[str] = Field(description="The repo ID to use for this model") repo_id: Optional[str] = Field(description="The repo ID to use for this model")
path: Optional[str] = Field(description="The path to the model") path: Optional[str] = Field(description="The path to the model")
class CreateModelRequest (BaseModel):
name: str = Field(description="The name of the model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info")
class CreateModelResponse (BaseModel):
name: str = Field(description="The name of the new model")
info: Union[CkptModelInfo, DiffusersModelInfo] = Field(..., discriminator="format", description="The model info")
status: str = Field(description="The status of the API response")
class ConvertedModelRequest (BaseModel):
name: str = Field(description="The name of the new model")
info: CkptModelInfo = Field(description="The converted model info")
class ConvertedModelResponse (BaseModel):
name: str = Field(description="The name of the new model")
info: DiffusersModelInfo = Field(description="The converted model info")
class ModelsList(BaseModel): class ModelsList(BaseModel):
models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]] models: dict[str, Annotated[Union[(CkptModelInfo,DiffusersModelInfo)], Field(discriminator="format")]]
@models_router.get( @models_router.get(
"/", "/",
operation_id="list_models", operation_id="list_models",
@ -54,226 +66,70 @@ async def list_models() -> ModelsList:
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
# @socketio.on("requestSystemConfig")
# def handle_request_capabilities():
# print(">> System config requested")
# config = self.get_system_config()
# config["model_list"] = self.generate.model_manager.list_models()
# config["infill_methods"] = infill_methods()
# socketio.emit("systemConfig", config)
# @socketio.on("searchForModels") @models_router.post(
# def handle_search_models(search_folder: str): "/",
# try: operation_id="update_model",
# if not search_folder: responses={
# socketio.emit( 201: {
# "foundModels", "model_response": "Model added",
# {"search_folder": None, "found_models": None}, },
# ) 202: {
# else: "description": "Model submission is processing. Check back later."
# ( },
# search_folder, },
# found_models, )
# ) = self.generate.model_manager.search_models(search_folder) async def update_model(
# socketio.emit( model_request: CreateModelRequest
# "foundModels", ) -> CreateModelResponse:
# {"search_folder": search_folder, "found_models": found_models}, """ Add Model """
# ) try:
# except Exception as e: model_request_info = model_request.info
# self.handle_exceptions(e) print(f">> Checking for {model_request_info}...")
# print("\n") info_dict = model_request_info.dict()
# @socketio.on("addNewModel") ApiDependencies.invoker.services.model_manager.add_model(
# def handle_add_model(new_model_config: dict): model_name=model_request.name,
# try: model_attributes=info_dict,
# model_name = new_model_config["name"] clobber=True,
# del new_model_config["name"] )
# model_attributes = new_model_config model_response = CreateModelResponse(name=model_request.name, info=model_request.info, status="success")
# if len(model_attributes["vae"]) == 0:
# del model_attributes["vae"]
# update = False
# current_model_list = self.generate.model_manager.list_models()
# if model_name in current_model_list:
# update = True
# print(f">> Adding New Model: {model_name}") except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
return model_response
# self.generate.model_manager.add_model(
# model_name=model_name,
# model_attributes=model_attributes,
# clobber=True,
# )
# self.generate.model_manager.commit(opt.conf)
# new_model_list = self.generate.model_manager.list_models() @models_router.delete(
# socketio.emit( "/{model_name}",
# "newModelAdded", operation_id="del_model",
# { responses={
# "new_model_name": model_name, 204: {
# "model_list": new_model_list, "description": "Model deleted"
# "update": update, },
# }, 404: {
# ) "description": "Model not found"
# print(f">> New Model Added: {model_name}") }
# except Exception as e: },
# self.handle_exceptions(e) )
async def delete_model(model_name: str) -> None:
"""Delete Model"""
model_names = ApiDependencies.invoker.services.model_manager.model_names()
model_exists = model_name in model_names
try:
# check if model exists
print(f">> Checking for model {model_name}...")
# @socketio.on("deleteModel") if not model_exists:
# def handle_delete_model(model_name: str): print(f">> Model not found")
# try: raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# print(f">> Deleting Model: {model_name}")
# self.generate.model_manager.del_model(model_name) # delete model
# self.generate.model_manager.commit(opt.conf) print(f">> Deleting Model: {model_name}")
# updated_model_list = self.generate.model_manager.list_models() ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
# socketio.emit( print(f">> Model Deleted: {model_name}")
# "modelDeleted", except Exception as e:
# { raise HTTPException(status_code=500, detail=str(e))
# "deleted_model_name": model_name,
# "model_list": updated_model_list,
# },
# )
# print(f">> Model Deleted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("requestModelChange")
# def handle_set_model(model_name: str):
# try:
# print(f">> Model change requested: {model_name}")
# model = self.generate.set_model(model_name)
# model_list = self.generate.model_manager.list_models()
# if model is None:
# socketio.emit(
# "modelChangeFailed",
# {"model_name": model_name, "model_list": model_list},
# )
# else:
# socketio.emit(
# "modelChanged",
# {"model_name": model_name, "model_list": model_list},
# )
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):
# try:
# if model_info := self.generate.model_manager.model_info(
# model_name=model_to_convert["model_name"]
# ):
# if "weights" in model_info:
# ckpt_path = Path(model_info["weights"])
# original_config_file = Path(model_info["config"])
# model_name = model_to_convert["model_name"]
# model_description = model_info["description"]
# else:
# self.socketio.emit(
# "error", {"message": "Model is not a valid checkpoint file"}
# )
# else:
# self.socketio.emit(
# "error", {"message": "Could not retrieve model info."}
# )
# if not ckpt_path.is_absolute():
# ckpt_path = Path(Globals.root, ckpt_path)
# if original_config_file and not original_config_file.is_absolute():
# original_config_file = Path(Globals.root, original_config_file)
# diffusers_path = Path(
# ckpt_path.parent.absolute(), f"{model_name}_diffusers"
# )
# if model_to_convert["save_location"] == "root":
# diffusers_path = Path(
# global_converted_ckpts_dir(), f"{model_name}_diffusers"
# )
# if (
# model_to_convert["save_location"] == "custom"
# and model_to_convert["custom_location"] is not None
# ):
# diffusers_path = Path(
# model_to_convert["custom_location"], f"{model_name}_diffusers"
# )
# if diffusers_path.exists():
# shutil.rmtree(diffusers_path)
# self.generate.model_manager.convert_and_import(
# ckpt_path,
# diffusers_path,
# model_name=model_name,
# model_description=model_description,
# vae=None,
# original_config_file=original_config_file,
# commit_to_conf=opt.conf,
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelConverted",
# {
# "new_model_name": model_name,
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Model Converted: {model_name}")
# except Exception as e:
# self.handle_exceptions(e)
# @socketio.on("mergeDiffusersModels")
# def merge_diffusers_models(model_merge_info: dict):
# try:
# models_to_merge = model_merge_info["models_to_merge"]
# model_ids_or_paths = [
# self.generate.model_manager.model_name_or_path(x)
# for x in models_to_merge
# ]
# merged_pipe = merge_diffusion_models(
# model_ids_or_paths,
# model_merge_info["alpha"],
# model_merge_info["interp"],
# model_merge_info["force"],
# )
# dump_path = global_models_dir() / "merged_models"
# if model_merge_info["model_merge_save_path"] is not None:
# dump_path = Path(model_merge_info["model_merge_save_path"])
# os.makedirs(dump_path, exist_ok=True)
# dump_path = dump_path / model_merge_info["merged_model_name"]
# merged_pipe.save_pretrained(dump_path, safe_serialization=1)
# merged_model_config = dict(
# model_name=model_merge_info["merged_model_name"],
# description=f'Merge of models {", ".join(models_to_merge)}',
# commit_to_conf=opt.conf,
# )
# if vae := self.generate.model_manager.config[models_to_merge[0]].get(
# "vae", None
# ):
# print(f">> Using configured VAE assigned to {models_to_merge[0]}")
# merged_model_config.update(vae=vae)
# self.generate.model_manager.import_diffuser_model(
# dump_path, **merged_model_config
# )
# new_model_list = self.generate.model_manager.list_models()
# socketio.emit(
# "modelsMerged",
# {
# "merged_models": models_to_merge,
# "merged_model_name": model_merge_info["merged_model_name"],
# "model_list": new_model_list,
# "update": True,
# },
# )
# print(f">> Models Merged: {models_to_merge}")
# print(f">> New Model Added: {model_merge_info['merged_model_name']}")
# except Exception as e:
# self.handle_exceptions(e)