fix(api): improve mm routes

This commit is contained in:
psychedelicious 2023-07-05 20:08:47 +10:00
parent 5d4d0e795c
commit 56d4ea3252

View File

@ -2,13 +2,18 @@
from typing import Literal, Optional, Union from typing import Literal, Optional, Union
from fastapi import Query, Body, Path from fastapi import Body, Path, Query, Response
from fastapi.routing import APIRouter, HTTPException from fastapi.routing import APIRouter
from pydantic import BaseModel, Field, parse_obj_as from pydantic import BaseModel, Field, parse_obj_as
from ..dependencies import ApiDependencies from starlette.exceptions import HTTPException
from invokeai.backend import BaseModelType, ModelType from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management import AddModelResult from invokeai.backend.model_management import AddModelResult
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType from invokeai.backend.model_management.models import (MODEL_CONFIGS,
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType)
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -75,6 +80,7 @@ async def update_model(
responses= { responses= {
201: {"description" : "The model imported successfully"}, 201: {"description" : "The model imported successfully"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
409: {"description" : "There is already a model corresponding to this path or repo_id"}, 409: {"description" : "There is already a model corresponding to this path or repo_id"},
}, },
status_code=201, status_code=201,
@ -96,7 +102,12 @@ async def import_model(
items_to_import = items_to_import, items_to_import = items_to_import,
prediction_type_helper = lambda x: prediction_types.get(prediction_type) prediction_type_helper = lambda x: prediction_types.get(prediction_type)
) )
if info := installed_models.get(name): info = installed_models.get(name)
if not info:
logger.error("Import failed")
raise HTTPException(status_code=424)
logger.info(f'Successfully imported {name}, got {info}') logger.info(f'Successfully imported {name}, got {info}')
return ImportModelResponse( return ImportModelResponse(
name = name, name = name,
@ -124,10 +135,10 @@ async def import_model(
}, },
) )
async def delete_model( async def delete_model(
base_model: BaseModelType = Path(default='sd-1', description="Base model"), base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(default='main', description="The type of model"), model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(default=None, description="model name"), model_name: str = Path(description="model name"),
) -> None: ) -> Response:
"""Delete Model""" """Delete Model"""
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
@ -137,14 +148,10 @@ async def delete_model(
model_type = model_type model_type = model_type
) )
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully") return Response(status_code=204)
except KeyError: except KeyError:
logger.error(f"Model not found: {model_name}") logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
else:
logger.info(f"Model deleted: {model_name}")
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
# @socketio.on("convertToDiffusers") # @socketio.on("convertToDiffusers")
# def convert_to_diffusers(model_to_convert: dict): # def convert_to_diffusers(model_to_convert: dict):