Adds (Untested) Add Delete Load Endpoints

This commit is contained in:
Kent Keirsey
2023-03-25 21:51:38 -04:00
parent 6263cb945c
commit b52b9985bd

View File

@ -2,7 +2,7 @@
from typing import Annotated, Any, List, Literal, Optional, Union from typing import Annotated, Any, List, Literal, Optional, Union
from fastapi.routing import APIRouter from fastapi.routing import APIRouter, HTTPException
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
@ -17,6 +17,7 @@ class VaeRepo(BaseModel):
class ModelInfo(BaseModel): class ModelInfo(BaseModel):
model_name: str = Field(..., description="The name of the model")
description: Optional[str] = Field(description="A description of the model") description: Optional[str] = Field(description="A description of the model")
@ -54,6 +55,90 @@ async def list_models() -> ModelsList:
models = parse_obj_as(ModelsList, { "models": models_raw }) models = parse_obj_as(ModelsList, { "models": models_raw })
return models return models
# Kent wrote the below code. It is highly suspect, and should be reviewed. Once all issues have been identified and fixed, this comment can be removed.
# Seriously, check this. I have no idea how to test it.
""" Add Model """
@models_router.post(
"/",
operation_id="add_model",
responses={201: {"model": Union[CkptModelInfo, DiffusersModelInfo], "new_model_list": ModelsList}},
)
async def add_model(
model_info: Union[CkptModelInfo, DiffusersModelInfo],
) -> Union[CkptModelInfo, DiffusersModelInfo]:
"""Adds a new model"""
try:
model_name = model_info["model_name"]
del model_info["model_name"]
model_attributes = model_info
if len(model_attributes.get("vae", [])) == 0:
del model_attributes["vae"]
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_name,
model_attributes=model_attributes,
clobber=True,
)
# How does Ckpt support deprecation change the above?
except Exception as e:
# Handle any exceptions thrown during the execution of the method
# or raise the exception to be handled by the global exception handler
raise HTTPException(status_code=500, detail=str(e))
return model_info
""" Delete Model """
@models_router.delete(
"/{model_name}",
operation_id="del_model",
responses={204: {"description": "Model deleted"}, 404: {"description": "Model not found"}},
)
async def delete_model(model_name: str) -> None:
try:
# check if model exists
if model_name not in ApiDependencies.invoker.services.model_manager.models:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# delete model
print(f">> Deleting Model: {model_name}")
ApiDependencies.invoker.services.model_manager.del_model(model_name, delete_files=True)
print(f">> Model Deleted: {model_name}")
except Exception as e:
# Handle any exceptions thrown during the execution of the method
raise HTTPException(status_code=500, detail=str(e))
""" Load Model """
models_router.post(
"/load/{model_name}",
operation_id="load_model",
responses={200: {"model": Union[CkptModelInfo, DiffusersModelInfo]}, 404: {"description": "Model not found"}},
)
async def load_model(model_name: str) -> Union[CkptModelInfo, DiffusersModelInfo]:
"""
Load an existing model by name
"""
try:
# check if model exists
if model_name not in ApiDependencies.invoker.services.model_manager.models:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
# load model
model_info = ApiDependencies.invoker.services.model_manager.load_model(model_name)
print(f">> Model Loaded: {model_name}")
return model_info
except Exception as e:
# Handle any exceptions thrown during the execution of the method
raise HTTPException(status_code=500, detail=str(e))
# @socketio.on("requestSystemConfig") # @socketio.on("requestSystemConfig")
# def handle_request_capabilities(): # def handle_request_capabilities():
# print(">> System config requested") # print(">> System config requested")