mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(api): improve mm routes
This commit is contained in:
parent
5d4d0e795c
commit
56d4ea3252
@ -2,13 +2,18 @@
|
|||||||
|
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from fastapi import Query, Body, Path
|
from fastapi import Body, Path, Query, Response
|
||||||
from fastapi.routing import APIRouter, HTTPException
|
from fastapi.routing import APIRouter
|
||||||
from pydantic import BaseModel, Field, parse_obj_as
|
from pydantic import BaseModel, Field, parse_obj_as
|
||||||
from ..dependencies import ApiDependencies
|
from starlette.exceptions import HTTPException
|
||||||
|
|
||||||
from invokeai.backend import BaseModelType, ModelType
|
from invokeai.backend import BaseModelType, ModelType
|
||||||
from invokeai.backend.model_management import AddModelResult
|
from invokeai.backend.model_management import AddModelResult
|
||||||
from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
|
from invokeai.backend.model_management.models import (MODEL_CONFIGS,
|
||||||
|
OPENAPI_MODEL_CONFIGS,
|
||||||
|
SchedulerPredictionType)
|
||||||
|
|
||||||
|
from ..dependencies import ApiDependencies
|
||||||
|
|
||||||
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
models_router = APIRouter(prefix="/v1/models", tags=["models"])
|
||||||
|
|
||||||
@ -75,6 +80,7 @@ async def update_model(
|
|||||||
responses= {
|
responses= {
|
||||||
201: {"description" : "The model imported successfully"},
|
201: {"description" : "The model imported successfully"},
|
||||||
404: {"description" : "The model could not be found"},
|
404: {"description" : "The model could not be found"},
|
||||||
|
424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
|
||||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||||
},
|
},
|
||||||
status_code=201,
|
status_code=201,
|
||||||
@ -96,7 +102,12 @@ async def import_model(
|
|||||||
items_to_import = items_to_import,
|
items_to_import = items_to_import,
|
||||||
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
prediction_type_helper = lambda x: prediction_types.get(prediction_type)
|
||||||
)
|
)
|
||||||
if info := installed_models.get(name):
|
info = installed_models.get(name)
|
||||||
|
|
||||||
|
if not info:
|
||||||
|
logger.error("Import failed")
|
||||||
|
raise HTTPException(status_code=424)
|
||||||
|
|
||||||
logger.info(f'Successfully imported {name}, got {info}')
|
logger.info(f'Successfully imported {name}, got {info}')
|
||||||
return ImportModelResponse(
|
return ImportModelResponse(
|
||||||
name = name,
|
name = name,
|
||||||
@ -124,10 +135,10 @@ async def import_model(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
async def delete_model(
|
async def delete_model(
|
||||||
base_model: BaseModelType = Path(default='sd-1', description="Base model"),
|
base_model: BaseModelType = Path(description="Base model"),
|
||||||
model_type: ModelType = Path(default='main', description="The type of model"),
|
model_type: ModelType = Path(description="The type of model"),
|
||||||
model_name: str = Path(default=None, description="model name"),
|
model_name: str = Path(description="model name"),
|
||||||
) -> None:
|
) -> Response:
|
||||||
"""Delete Model"""
|
"""Delete Model"""
|
||||||
logger = ApiDependencies.invoker.services.logger
|
logger = ApiDependencies.invoker.services.logger
|
||||||
|
|
||||||
@ -137,14 +148,10 @@ async def delete_model(
|
|||||||
model_type = model_type
|
model_type = model_type
|
||||||
)
|
)
|
||||||
logger.info(f"Deleted model: {model_name}")
|
logger.info(f"Deleted model: {model_name}")
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
return Response(status_code=204)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
logger.error(f"Model not found: {model_name}")
|
logger.error(f"Model not found: {model_name}")
|
||||||
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
|
||||||
else:
|
|
||||||
logger.info(f"Model deleted: {model_name}")
|
|
||||||
raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
|
|
||||||
|
|
||||||
|
|
||||||
# @socketio.on("convertToDiffusers")
|
# @socketio.on("convertToDiffusers")
|
||||||
# def convert_to_diffusers(model_to_convert: dict):
|
# def convert_to_diffusers(model_to_convert: dict):
|
||||||
|
Loading…
Reference in New Issue
Block a user