diff --git a/invokeai/app/api/routers/models.py b/invokeai/app/api/routers/models.py index c298114cbc..923a3767a3 100644 --- a/invokeai/app/api/routers/models.py +++ b/invokeai/app/api/routers/models.py @@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType from invokeai.backend.model_management.models import ( OPENAPI_MODEL_CONFIGS, SchedulerPredictionType, + ModelNotFoundException, ) from invokeai.backend.model_management import MergeInterpolationMethod + from ..dependencies import ApiDependencies models_router = APIRouter(prefix="/v1/models", tags=["models"]) @@ -46,8 +48,9 @@ async def list_models( "/{base_model}/{model_type}/{model_name}", operation_id="update_model", responses={200: {"description" : "The model was updated successfully"}, + 400: {"description" : "Bad request"}, 404: {"description" : "The model could not be found"}, - 400: {"description" : "Bad request"} + 409: {"description" : "There is already a model corresponding to the new name"}, }, status_code = 200, response_model = UpdateModelResponse, @@ -58,23 +61,43 @@ async def update_model( model_name: str = Path(description="model name"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), ) -> UpdateModelResponse: - """ Add Model """ + """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """ + logger = ApiDependencies.invoker.services.logger try: + # rename operation requested + if info.model_name != model_name or info.base_model != base_model: + result = ApiDependencies.invoker.services.model_manager.rename_model( + base_model = base_model, + model_type = model_type, + model_name = model_name, + new_name = info.model_name, + new_base = info.base_model, + ) + logger.debug(f'renaming result = {result}') + logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}') + model_name = info.model_name + base_model = info.base_model + ApiDependencies.invoker.services.model_manager.update_model( model_name=model_name, base_model=base_model, model_type=model_type, model_attributes=info.dict() ) + model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_name=model_name, base_model=base_model, model_type=model_type, ) model_response = parse_obj_as(UpdateModelResponse, model_raw) - except KeyError as e: + except ModelNotFoundException as e: raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: + logger.error(str(e)) + raise HTTPException(status_code=409, detail=str(e)) + except Exception as e: + logger.error(str(e)) raise HTTPException(status_code=400, detail=str(e)) return model_response @@ -121,7 +144,7 @@ async def import_model( ) return parse_obj_as(ImportModelResponse, model_raw) - except KeyError as e: + except ModelNotFoundException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: @@ -161,57 +184,13 @@ async def add_model( model_type=info.model_type ) return parse_obj_as(ImportModelResponse, model_raw) - except KeyError as e: + except ModelNotFoundException as e: logger.error(str(e)) raise HTTPException(status_code=404, detail=str(e)) except ValueError as e: logger.error(str(e)) raise HTTPException(status_code=409, detail=str(e)) -@models_router.post( - "/rename/{base_model}/{model_type}/{model_name}", - operation_id="rename_model", - responses= { - 201: {"description" : "The model was renamed successfully"}, - 404: {"description" : "The model could not be found"}, - 409: {"description" : "There is already a model corresponding to the new name"}, - }, - status_code=201, - response_model=ImportModelResponse -) -async def rename_model( - base_model: BaseModelType = Path(description="Base model"), - model_type: ModelType = Path(description="The type of model"), - model_name: str = Path(description="current model name"), - new_name: Optional[str] = Query(description="new model name", default=None), - new_base: Optional[BaseModelType] = Query(description="new model base", default=None), -) -> ImportModelResponse: - """ Rename a model""" - - logger = ApiDependencies.invoker.services.logger - - try: - result = ApiDependencies.invoker.services.model_manager.rename_model( - base_model = base_model, - model_type = model_type, - model_name = model_name, - new_name = new_name, - new_base = new_base, - ) - logger.debug(result) - logger.info(f'Successfully renamed {model_name}=>{new_name}') - model_raw = ApiDependencies.invoker.services.model_manager.list_model( - model_name=new_name or model_name, - base_model=new_base or base_model, - model_type=model_type - ) - return parse_obj_as(ImportModelResponse, model_raw) - except KeyError as e: - logger.error(str(e)) - raise HTTPException(status_code=404, detail=str(e)) - except ValueError as e: - logger.error(str(e)) - raise HTTPException(status_code=409, detail=str(e)) @models_router.delete( "/{base_model}/{model_type}/{model_name}", @@ -238,9 +217,9 @@ async def delete_model( ) logger.info(f"Deleted model: {model_name}") return Response(status_code=204) - except KeyError: - logger.error(f"Model not found: {model_name}") - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + except ModelNotFoundException as e: + logger.error(str(e)) + raise HTTPException(status_code=404, detail=str(e)) @models_router.put( "/convert/{base_model}/{model_type}/{model_name}", @@ -273,8 +252,8 @@ async def convert_model( base_model = base_model, model_type = model_type) response = parse_obj_as(ConvertModelResponse, model_raw) - except KeyError: - raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") + except ModelNotFoundException as e: + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response @@ -364,8 +343,55 @@ async def merge_models( model_type = ModelType.Main, ) response = parse_obj_as(ConvertModelResponse, model_raw) - except KeyError: + except ModelNotFoundException: raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return response + +# The rename operation is now supported by update_model and no longer needs to be +# a standalone route. +# @models_router.post( +# "/rename/{base_model}/{model_type}/{model_name}", +# operation_id="rename_model", +# responses= { +# 201: {"description" : "The model was renamed successfully"}, +# 404: {"description" : "The model could not be found"}, +# 409: {"description" : "There is already a model corresponding to the new name"}, +# }, +# status_code=201, +# response_model=ImportModelResponse +# ) +# async def rename_model( +# base_model: BaseModelType = Path(description="Base model"), +# model_type: ModelType = Path(description="The type of model"), +# model_name: str = Path(description="current model name"), +# new_name: Optional[str] = Query(description="new model name", default=None), +# new_base: Optional[BaseModelType] = Query(description="new model base", default=None), +# ) -> ImportModelResponse: +# """ Rename a model""" + +# logger = ApiDependencies.invoker.services.logger + +# try: +# result = ApiDependencies.invoker.services.model_manager.rename_model( +# base_model = base_model, +# model_type = model_type, +# model_name = model_name, +# new_name = new_name, +# new_base = new_base, +# ) +# logger.debug(result) +# logger.info(f'Successfully renamed {model_name}=>{new_name}') +# model_raw = ApiDependencies.invoker.services.model_manager.list_model( +# model_name=new_name or model_name, +# base_model=new_base or base_model, +# model_type=model_type +# ) +# return parse_obj_as(ImportModelResponse, model_raw) +# except ModelNotFoundException as e: +# logger.error(str(e)) +# raise HTTPException(status_code=404, detail=str(e)) +# except ValueError as e: +# logger.error(str(e)) +# raise HTTPException(status_code=409, detail=str(e)) diff --git a/invokeai/app/services/model_manager_service.py b/invokeai/app/services/model_manager_service.py index 67db5c9478..7dba1dff06 100644 --- a/invokeai/app/services/model_manager_service.py +++ b/invokeai/app/services/model_manager_service.py @@ -18,6 +18,7 @@ from invokeai.backend.model_management import ( SchedulerPredictionType, ModelMerger, MergeInterpolationMethod, + ModelNotFoundException, ) from invokeai.backend.model_management.model_search import FindModels @@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC): ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with a - KeyErrorException if the name does not already exist. + ModelNotFoundException if the name does not already exist. On a successful update, the config will be changed in memory. Will fail with an assertion error if provided attributes are incorrect or @@ -451,14 +452,14 @@ class ModelManagerService(ModelManagerServiceBase): ) -> AddModelResult: """ Update the named model with a dictionary of attributes. Will fail with a - KeyError exception if the name does not already exist. + ModelNotFoundException exception if the name does not already exist. On a successful update, the config will be changed in memory. Will fail with an assertion error if provided attributes are incorrect or the model name is missing. Call commit() to write changes to disk. """ self.logger.debug(f'update model {model_name}') if not self.model_exists(model_name, base_model, model_type): - raise KeyError(f"Unknown model {model_name}") + raise ModelNotFoundException(f"Unknown model {model_name}") return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) def del_model( diff --git a/invokeai/backend/model_management/__init__.py b/invokeai/backend/model_management/__init__.py index e31085acef..086a8721a1 100644 --- a/invokeai/backend/model_management/__init__.py +++ b/invokeai/backend/model_management/__init__.py @@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management """ from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_cache import ModelCache -from .models import BaseModelType, ModelType, SubModelType, ModelVariantType +from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException from .model_merge import ModelMerger, MergeInterpolationMethod diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 55f6de9b5b..e1c0a2a85b 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -552,7 +552,7 @@ class ModelManager(object): model_config = self.models.get(model_key) if not model_config: self.logger.error(f'Unknown model {model_name}') - raise KeyError(f'Unknown model {model_name}') + raise ModelNotFoundException(f'Unknown model {model_name}') cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) if base_model is not None and cur_base_model != base_model: @@ -596,7 +596,7 @@ class ModelManager(object): model_cfg = self.models.pop(model_key, None) if model_cfg is None: - raise KeyError(f"Unknown model {model_key}") + raise ModelNotFoundException(f"Unknown model {model_key}") # note: it not garantie to release memory(model can has other references) cache_ids = self.cache_keys.pop(model_key, []) @@ -689,7 +689,7 @@ class ModelManager(object): model_key = self.create_key(model_name, base_model, model_type) model_cfg = self.models.get(model_key, None) if not model_cfg: - raise KeyError(f"Unknown model: {model_key}") + raise ModelNotFoundException(f"Unknown model: {model_key}") old_path = self.app_config.root_path / model_cfg.path new_name = new_name or model_name @@ -965,7 +965,7 @@ class ModelManager(object): that model. May return the following exceptions: - - KeyError - one or more of the items to import is not a valid path, repo_id or URL + - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL - ValueError - a corresponding model already exists ''' # avoid circular import here