mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
accept @psychedelicious suggestions above
@@ -2,25 +2,28 @@
 from typing import Literal, Optional, Union
 
-from fastapi import Query, Body, Path
-from fastapi.routing import APIRouter, HTTPException
+from fastapi import Body, Path, Query, Response
+from fastapi.routing import APIRouter
 from pydantic import BaseModel, Field, parse_obj_as
-from ..dependencies import ApiDependencies
+from starlette.exceptions import HTTPException
 
 from invokeai.backend import BaseModelType, ModelType
 from invokeai.backend.model_management import AddModelResult
-from invokeai.backend.model_management.models import MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS, SchedulerPredictionType
+from invokeai.backend.model_management.models import (MODEL_CONFIGS,
+                                                      OPENAPI_MODEL_CONFIGS,
+                                                      SchedulerPredictionType)
+
+from ..dependencies import ApiDependencies
 
 models_router = APIRouter(prefix="/v1/models", tags=["models"])
 
-class CreateModelResponse(BaseModel):
+class UpdateModelResponse(BaseModel):
     model_name: str = Field(description="The name of the new model")
     info: Union[tuple(MODEL_CONFIGS)] = Field(description="The model info")
     status: str = Field(description="The status of the API response")
 
 class ImportModelResponse(BaseModel):
-    name: str = Field(description="The name of the imported model")
+    location: str = Field(description="The path, repo_id or URL of the imported model")
     info: AddModelResult = Field(description="The model info")
     status: str = Field(description="The status of the API response")
 
 class ConvertModelResponse(BaseModel):
     name: str = Field(description="The name of the imported model")
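
The info field's Union[tuple(MODEL_CONFIGS)] annotation works because typing.Union accepts a tuple of types, so the registered config classes can be folded into one pydantic field type without listing them by hand. A minimal sketch of the pattern follows; DiffusersConfig and CheckpointConfig are hypothetical stand-ins for the real entries of MODEL_CONFIGS.

# Minimal sketch of the Union[tuple(...)] pattern used for the `info` fields.
# DiffusersConfig and CheckpointConfig are hypothetical stand-ins for the
# entries of MODEL_CONFIGS.
from typing import Union

from pydantic import BaseModel, Field


class DiffusersConfig(BaseModel):
    path: str = Field(description="Path to a diffusers model folder")


class CheckpointConfig(BaseModel):
    path: str = Field(description="Path to a .ckpt or .safetensors file")


MODEL_CONFIGS = [DiffusersConfig, CheckpointConfig]

# typing.Union accepts a tuple of types, so the registered config classes can
# be folded into a single annotation without naming each class by hand.
AnyModelConfig = Union[tuple(MODEL_CONFIGS)]


class UpdateModelResponse(BaseModel):
    model_name: str = Field(description="The name of the new model")
    info: AnyModelConfig = Field(description="The model info")
    status: str = Field(description="The status of the API response")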
@@ -48,51 +51,65 @@ async def list_models(
     models = parse_obj_as(ModelsList, { "models": models_raw })
     return models
 
-@models_router.post(
+@models_router.patch(
     "/{base_model}/{model_type}/{model_name}",
     operation_id="update_model",
-    responses={200: {"status": "success"}},
+    responses={200: {"description" : "The model was updated successfully"},
+               404: {"description" : "The model could not be found"},
+               400: {"description" : "Bad request"}
+               },
+    status_code = 200,
+    response_model = UpdateModelResponse,
 )
 async def update_model(
     base_model: BaseModelType = Path(default='sd-1', description="Base model"),
     model_type: ModelType = Path(default='main', description="The type of model"),
     model_name: str = Path(default=None, description="model name"),
     info: Union[tuple(MODEL_CONFIGS)] = Body(description="Model configuration"),
-) -> CreateModelResponse:
+) -> UpdateModelResponse:
     """ Add Model """
-    ApiDependencies.invoker.services.model_manager.add_model(
-        model_name=model_name,
-        base_model=base_model,
-        model_type=model_type,
-        model_attributes=info.dict(),
-        clobber=True,
-    )
-    model_response = CreateModelResponse(
-        model_name = model_name,
-        info = info,
-        status="success")
+    try:
+        ApiDependencies.invoker.services.model_manager.update_model(
+            model_name=model_name,
+            base_model=base_model,
+            model_type=model_type,
+            model_attributes=info.dict()
+        )
+        model_response = UpdateModelResponse(
+            model_name = model_name,
+            info = ApiDependencies.invoker.services.model_manager.model_info(
+                model_name=model_name,
+                base_model=base_model,
+                model_type=model_type,
+            )
+        )
+    except KeyError as e:
+        raise HTTPException(status_code=404, detail=str(e))
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
 
     return model_response
 
 @models_router.post(
-    "/import",
+    "/",
     operation_id="import_model",
     responses= {
         201: {"description" : "The model imported successfully"},
         404: {"description" : "The model could not be found"},
         424: {"description" : "The model appeared to import successfully, but could not be found in the model manager"},
         409: {"description" : "There is already a model corresponding to this path or repo_id"},
     },
     status_code=201,
     response_model=ImportModelResponse
 )
 async def import_model(
-    name: str = Body(description="A model path, repo_id or URL to import"),
+    location: str = Body(description="A model path, repo_id or URL to import"),
     prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
         Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
 ) -> ImportModelResponse:
     """ Add a model using its local path, repo_id, or remote URL """
 
-    items_to_import = {name}
+    items_to_import = {location}
     prediction_types = { x.value: x for x in SchedulerPredictionType }
     logger = ApiDependencies.invoker.services.logger
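
For reference, a hypothetical client-side sketch of the two reworked routes, assuming an InvokeAI server reachable at http://127.0.0.1:9090 with the API mounted under /api; the host, port, and update payload fields are assumptions for illustration, not taken from this diff.

# Hypothetical client-side calls against the reworked routes. BASE_URL, the
# /api prefix, and the update payload fields are assumptions for illustration.
import requests

BASE_URL = "http://127.0.0.1:9090/api/v1/models"

# Update an existing model: PATCH /{base_model}/{model_type}/{model_name}.
# The request body is the model's configuration; the exact fields depend on
# the config schema of the installed model.
resp = requests.patch(
    f"{BASE_URL}/sd-1/main/my-model",
    json={"model_name": "my-model", "description": "updated description"},
)
if resp.status_code == 200:
    print("updated:", resp.json()["model_name"])
elif resp.status_code == 404:
    print("model not found")

# Import a model: POST its path, repo_id or URL to the collection root.
resp = requests.post(
    f"{BASE_URL}/",
    json={"location": "runwayml/stable-diffusion-v1-5"},
)
if resp.status_code == 201:
    print("imported:", resp.json()["location"])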
@@ -101,12 +118,16 @@ async def import_model(
             items_to_import = items_to_import,
             prediction_type_helper = lambda x: prediction_types.get(prediction_type)
         )
-        if info := installed_models.get(name):
-            logger.info(f'Successfully imported {name}, got {info}')
-            return ImportModelResponse(
-                name = name,
+        info = installed_models.get(location)
+
+        if not info:
+            logger.error("Import failed")
+            raise HTTPException(status_code=424)
+
+        logger.info(f'Successfully imported {location}, got {info}')
+        return ImportModelResponse(
+            location = location,
             info = info,
             status = "success",
         )
     except KeyError as e:
         logger.error(str(e))
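
The import flow now binds the lookup result first and fails fast with a 424 when the installer did not register the model, instead of nesting the success path under an if-walrus. A standalone sketch of that shape, using an illustrative helper rather than InvokeAI's actual installer API:

# Guard-clause sketch of the reworked import flow. The installed_models dict
# and the return shape are illustrative stand-ins, not InvokeAI's installer API.
from fastapi import HTTPException


def finish_import(installed_models: dict, location: str) -> dict:
    info = installed_models.get(location)

    if not info:
        # Fail fast: the installer returned, but the model never registered.
        raise HTTPException(status_code=424)

    # The success path stays at the top indentation level instead of being
    # nested inside an `if info := installed_models.get(location):` block.
    return {"location": location, "info": info, "status": "success"}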
@@ -129,10 +150,10 @@ async def import_model(
     },
 )
 async def delete_model(
-    base_model: BaseModelType = Path(default='sd-1', description="Base model"),
-    model_type: ModelType = Path(default='main', description="The type of model"),
-    model_name: str = Path(default=None, description="model name"),
-) -> None:
+    base_model: BaseModelType = Path(description="Base model"),
+    model_type: ModelType = Path(description="The type of model"),
+    model_name: str = Path(description="model name"),
+) -> Response:
     """Delete Model"""
     logger = ApiDependencies.invoker.services.logger
 
@@ -142,14 +163,10 @@ async def delete_model(
                                                                  model_type = model_type
                                                                  )
-        logger.info(f"Deleted model: {model_name}")
-        raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
+        return Response(status_code=204)
     except KeyError:
         logger.error(f"Model not found: {model_name}")
         raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
-    else:
-        logger.info(f"Model deleted: {model_name}")
-        raise HTTPException(status_code=204, detail=f"Model '{model_name}' deleted successfully")
 
 
 # @socketio.on("convertToDiffusers")
 # def convert_to_diffusers(model_to_convert: dict):
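
Successful deletion is now signalled by returning Response(status_code=204) rather than raising HTTPException with a 2xx code: raising stays reserved for error paths, and a 204 reply carries no body. A minimal self-contained FastAPI sketch of the same pattern, using a hypothetical in-memory registry rather than InvokeAI's model manager:

# Minimal FastAPI sketch of the 204-on-success pattern adopted for delete_model.
# The in-memory `registry` and route path are hypothetical, not InvokeAI's
# model manager.
from fastapi import FastAPI, HTTPException, Response

app = FastAPI()
registry = {"sd-1/main/my-model": {"path": "/models/my-model"}}


@app.delete("/models/{key:path}", status_code=204)
def delete_model(key: str) -> Response:
    try:
        del registry[key]
    except KeyError:
        # Error paths still raise, so FastAPI renders the detail as JSON.
        raise HTTPException(status_code=404, detail=f"Model '{key}' not found")
    # Success is returned, not raised: 204 No Content with an empty body.
    return Response(status_code=204)

A client can then treat a 204 status as success and a 404 as a missing model, without parsing a response body on the success path.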