fix(api): improve mm routes

This commit is contained in:
psychedelicious 2023-07-05 20:08:47 +10:00
parent 5d4d0e795c
commit 56d4ea3252

View File

@ -2,13 +2,18 @@
from typing import Literal, Optional, Union
from fastapi import Query, Body, Path
from fastapi.routing import APIRouter, HTTPException
from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies
from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -75,6 +80,7 @@ async def update_model(
responses= {
201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"},
},
status_code=201,
@ -96,7 +102,12 @@ async def import_model(
items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
)
if info := installed_models.get(name):
info = installed_models.get(name)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse(
name = name,
@ -124,10 +135,10 @@ async def import_model(
},
)
async def delete_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"),
model_name: str = Path(default=None, description="model name"),
) -> None:
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="model name"),
) -> Response:
"""Delete Model"""
logger = ApiDependencies.invoker.services.logger
@ -137,14 +148,10 @@ async def delete_model(
model_type = model_type
)
logger.info(f"Deleted model: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
else:
logger.info(f"Model deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
# @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict):