add renaming capabilities to model update API route

This commit is contained in:
Lincoln Stein 2023-07-16 14:17:05 -04:00
parent b56be07ab3
commit 6fbb5ce780
4 changed files with 90 additions and 63 deletions

View File

@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType,
ModelNotFoundException,
)
from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -46,8 +48,9 @@ async def list_models(
"/{base_model}/{model_type}/{model_name}",
operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"}
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code = 200,
response_model = UpdateModelResponse,
@ -58,23 +61,43 @@ async def update_model(
model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse:
""" Add Model """
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try:
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.debug(f'renaming result = {result}')
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
model_name = info.model_name
base_model = info.base_model
ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
model_attributes=info.dict()
)
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e))
return model_response
@ -121,7 +144,7 @@ async def import_model(
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
@ -161,57 +184,13 @@ async def add_model(
model_type=info.model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete(
"/{base_model}/{model_type}/{model_name}",
@ -238,9 +217,9 @@ async def delete_model(
)
logger.info(f"Deleted model: {model_name}")
return Response(status_code=204)
except KeyError:
logger.error(f"Model not found: {model_name}")
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
@models_router.put(
"/convert/{base_model}/{model_type}/{model_name}",
@ -273,8 +252,8 @@ async def convert_model(
base_model = base_model,
model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found")
except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
@ -364,8 +343,55 @@ async def merge_models(
model_type = ModelType.Main,
)
response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError:
except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
SchedulerPredictionType,
ModelMerger,
MergeInterpolationMethod,
ModelNotFoundException,
)
from invokeai.backend.model_management.model_search import FindModels
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist.
ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
@ -451,14 +452,14 @@ class ModelManagerService(ModelManagerServiceBase):
) -> AddModelResult:
"""
Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist.
ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk.
"""
self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}")
raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model(

View File

@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
"""
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -552,7 +552,7 @@ class ModelManager(object):
model_config = self.models.get(model_key)
if not model_config:
self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}')
raise ModelNotFoundException(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model:
@ -596,7 +596,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None)
if model_cfg is None:
raise KeyError(f"Unknown model {model_key}")
raise ModelNotFoundException(f"Unknown model {model_key}")
# note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, [])
@ -689,7 +689,7 @@ class ModelManager(object):
model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None)
if not model_cfg:
raise KeyError(f"Unknown model: {model_key}")
raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name
@ -965,7 +965,7 @@ class ModelManager(object):
that model.
May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL
- ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists
'''
# avoid circular import here