mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add rename_model route
This commit is contained in:
@ -23,6 +23,7 @@ UpdateModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ConvertModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
MergeModelResponse = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
ImportModelAttributes = Union[tuple(OPENAPI_MODEL_CONFIGS)]
|
||||
|
||||
class ModelsList(BaseModel):
|
||||
models: list[Union[tuple(OPENAPI_MODEL_CONFIGS)]]
|
||||
@ -79,7 +80,7 @@ async def update_model(
|
||||
return model_response
|
||||
|
||||
@models_router.post(
|
||||
"/",
|
||||
"/import",
|
||||
operation_id="import_model",
|
||||
responses= {
|
||||
201: {"description" : "The model imported successfully"},
|
||||
@ -95,7 +96,7 @@ async def import_model(
|
||||
prediction_type: Optional[Literal['v_prediction','epsilon','sample']] = \
|
||||
Body(description='Prediction type for SDv2 checkpoint files', default="v_prediction"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using its local path, repo_id, or remote URL """
|
||||
""" Add a model using its local path, repo_id, or remote URL. Model characteristics will be probed and configured automatically """
|
||||
|
||||
items_to_import = {location}
|
||||
prediction_types = { x.value: x for x in SchedulerPredictionType }
|
||||
@ -127,7 +128,91 @@ async def import_model(
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/add",
|
||||
operation_id="add_model",
|
||||
responses= {
|
||||
201: {"description" : "The model added successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
424: {"description" : "The model appeared to add successfully, but could not be found in the model manager"},
|
||||
409: {"description" : "There is already a model corresponding to this path or repo_id"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def add_model(
|
||||
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
|
||||
) -> ImportModelResponse:
|
||||
""" Add a model using the configuration information appropriate for its type. Only local models can be added by path"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
ApiDependencies.invoker.services.model_manager.add_model(
|
||||
info.model_name,
|
||||
info.base_model,
|
||||
info.model_type,
|
||||
model_attributes = info.dict()
|
||||
)
|
||||
logger.info(f'Successfully added {info.model_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=info.model_name,
|
||||
base_model=info.base_model,
|
||||
model_type=info.model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.post(
|
||||
"/rename/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="rename_model",
|
||||
responses= {
|
||||
201: {"description" : "The model was renamed successfully"},
|
||||
404: {"description" : "The model could not be found"},
|
||||
409: {"description" : "There is already a model corresponding to the new name"},
|
||||
},
|
||||
status_code=201,
|
||||
response_model=ImportModelResponse
|
||||
)
|
||||
async def rename_model(
|
||||
base_model: BaseModelType = Path(description="Base model"),
|
||||
model_type: ModelType = Path(description="The type of model"),
|
||||
model_name: str = Path(description="current model name"),
|
||||
new_name: Optional[str] = Query(description="new model name", default=None),
|
||||
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
|
||||
) -> ImportModelResponse:
|
||||
""" Rename a model"""
|
||||
|
||||
logger = ApiDependencies.invoker.services.logger
|
||||
|
||||
try:
|
||||
result = ApiDependencies.invoker.services.model_manager.rename_model(
|
||||
base_model = base_model,
|
||||
model_type = model_type,
|
||||
model_name = model_name,
|
||||
new_name = new_name,
|
||||
new_base = new_base,
|
||||
)
|
||||
logger.debug(result)
|
||||
logger.info(f'Successfully renamed {model_name}=>{new_name}')
|
||||
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
|
||||
model_name=new_name or model_name,
|
||||
base_model=new_base or base_model,
|
||||
model_type=model_type
|
||||
)
|
||||
return parse_obj_as(ImportModelResponse, model_raw)
|
||||
except KeyError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except ValueError as e:
|
||||
logger.error(str(e))
|
||||
raise HTTPException(status_code=409, detail=str(e))
|
||||
|
||||
@models_router.delete(
|
||||
"/{base_model}/{model_type}/{model_name}",
|
||||
operation_id="del_model",
|
||||
|
Reference in New Issue
Block a user