accept @psychedelicious suggestions above

This commit is contained in:
Lincoln Stein
2023-07-05 14:50:57 -04:00
3 changed files with 108 additions and 53 deletions

View File

@ -2,25 +2,28 @@
from typing import Literal, Optional, Union
from fastapi import Query, Body, Path
from fastapi.routing import APIRouter, HTTPException
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
class CreateModelResponse(BaseModel):
class UpdateModelResponse(BaseModel):
model_name: str = Field(description="The name of the new model")
info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ImportModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
location: str = Field(description="The path, repo_id or URL of the imported model")
info: AddModelResult = Field(description="The model info")
status: str = Field(description="The status of the API response")
class ConvertModelResponse(BaseModel):
name: str = Field(description="The name of the imported model")
@ -48,51 +51,65 @@ async def list_models(
models = parse_obj_as(ModelsList, { "models": models_raw })
return models
@models_router.post(
@models_router.patch(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"status": "success"}},
responses={200: {"description" : "The model was updated successfully"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
},
status_code = 200,
response_model = UpdateModelResponse,
)
async def update_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> CreateModelResponse:
) -> UpdateModelResponse:
""" Add Model """
ApiDependencies.invoker.services.model_manager.add_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict(),
clobber=True,
)
model_response = CreateModelResponse(
model_name = model_name,
info = info,
status="success")
try:
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_response = UpdateModelResponse(
model_name = model_name,
info = ApiDependencies.invoker.services.model_manager.model_info(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
)
except KeyError as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return model_response
@models_router.post(
"/import",
"/",
operation_id="import_model",
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
response_model=ImportModelResponse
)
async def import_model(
name: str = Body(description="A model path, repo_id or URL to import"),
location: str = Body(description="A model path, repo_id or URL to import"),
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
) -> ImportModelResponse:
""" Add a model using its local path, repo_id, or remote URL """
items_to_import = {name}
items_to_import = {location}
prediction_types = { x.value: x for x in SchedulerPredictionType }
logger = ApiDependencies.invoker.services.logger
@ -101,12 +118,16 @@ async def import_model(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if info := installed_models.get(name):
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
info = installed_models.get(location)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {location}, got {info}')
return ImportModelResponse(
location = location,
info = info,
status = "success",
)
except KeyError as e:
logger.error(str(e))
@ -129,10 +150,10 @@ async def import_model(
},
)
async def delete_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
) -> None:
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
@ -142,14 +163,10 @@ async def delete_model(
model_type = model_type
)
logger.info(f"Deleted model: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
else:
logger.info(f"Model deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):