add renaming capabilities to model update API route (#3793)

This PR allows the `update_model` API call to change the model's name
and/or base type as well. The `rename_model` call has accordingly been
retired.
This commit is contained in:
blessedcoolant 2023-07-17 08:52:59 +12:00 committed by GitHub
commit 32994a261a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 90 additions and 63 deletions

View File

@ -13,8 +13,10 @@ from invokeai.backend import BaseModelType, ModelType
from invokeai.backend.model_management.models import ( from invokeai.backend.model_management.models import (
OPENAPI_MODEL_CONFIGS, OPENAPI_MODEL_CONFIGS,
SchedulerPredictionType, SchedulerPredictionType,
ModelNotFoundException,
) )
from invokeai.backend.model_management import MergeInterpolationMethod from invokeai.backend.model_management import MergeInterpolationMethod
from ..dependencies import ApiDependencies from ..dependencies import ApiDependencies
models_router = APIRouter(prefix="/v1/models", tags=["models"]) models_router = APIRouter(prefix="/v1/models", tags=["models"])
@ -46,8 +48,9 @@ async def list_models(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
operation_id="update_model", operation_id="update_model",
responses={200: {"description" : "The model was updated successfully"}, responses={200: {"description" : "The model was updated successfully"},
400: {"description" : "Bad request"},
404: {"description" : "The model could not be found"}, 404: {"description" : "The model could not be found"},
400: {"description" : "Bad request"} 409: {"description" : "There is already a model corresponding to the new name"},
}, },
status_code = 200, status_code = 200,
response_model = UpdateModelResponse, response_model = UpdateModelResponse,
@ -58,23 +61,43 @@ async def update_model(
model_name: str = Path(description="model name"), model_name: str = Path(description="model name"),
info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"), info: Union[tuple(OPENAPI_MODEL_CONFIGS)] = Body(description="Model configuration"),
) -> UpdateModelResponse: ) -> UpdateModelResponse:
""" Add Model """ """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger
try: try:
# rename operation requested
if info.model_name != model_name or info.base_model != base_model:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = info.model_name,
new_base = info.base_model,
)
logger.debug(f'renaming result = {result}')
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
model_name = info.model_name
base_model = info.base_model
ApiDependencies.invoker.services.model_manager.update_model( ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
model_attributes=info.dict() model_attributes=info.dict()
) )
model_raw = ApiDependencies.invoker.services.model_manager.list_model( model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name, model_name=model_name,
base_model=base_model, base_model=base_model,
model_type=model_type, model_type=model_type,
) )
model_response = parse_obj_as(UpdateModelResponse, model_raw) model_response = parse_obj_as(UpdateModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
except Exception as e:
logger.error(str(e))
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return model_response return model_response
@ -121,7 +144,7 @@ async def import_model(
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
@ -161,57 +184,13 @@ async def add_model(
model_type=info.model_type model_type=info.model_type
) )
return parse_obj_as(ImportModelResponse, model_raw) return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e: except ModelNotFoundException as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e)) raise HTTPException(status_code=404, detail=str(e))
except ValueError as e: except ValueError as e:
logger.error(str(e)) logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e)) raise HTTPException(status_code=409, detail=str(e))
@models_router.post(
"/rename/{base_model}/{model_type}/{model_name}",
operation_id="rename_model",
responses= {
201: {"description" : "The model was renamed successfully"},
404: {"description" : "The model could not be found"},
409: {"description" : "There is already a model corresponding to the new name"},
},
status_code=201,
response_model=ImportModelResponse
)
async def rename_model(
base_model: BaseModelType = Path(description="Base model"),
model_type: ModelType = Path(description="The type of model"),
model_name: str = Path(description="current model name"),
new_name: Optional[str] = Query(description="new model name", default=None),
new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
) -> ImportModelResponse:
""" Rename a model"""
logger = ApiDependencies.invoker.services.logger
try:
result = ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model,
model_type = model_type,
model_name = model_name,
new_name = new_name,
new_base = new_base,
)
logger.debug(result)
logger.info(f'Successfully renamed {model_name}=>{new_name}')
model_raw = ApiDependencies.invoker.services.model_manager.list_model(
model_name=new_name or model_name,
base_model=new_base or base_model,
model_type=model_type
)
return parse_obj_as(ImportModelResponse, model_raw)
except KeyError as e:
logger.error(str(e))
raise HTTPException(status_code=404, detail=str(e))
except ValueError as e:
logger.error(str(e))
raise HTTPException(status_code=409, detail=str(e))
@models_router.delete( @models_router.delete(
"/{base_model}/{model_type}/{model_name}", "/{base_model}/{model_type}/{model_name}",
@ -238,9 +217,9 @@ async def delete_model(
) )
logger.info(f"Deleted model: {model_name}") logger.info(f"Deleted model: {model_name}")
return Response(status_code=204) return Response(status_code=204)
except KeyError: except ModelNotFoundException as e:
logger.error(f"Model not found: {model_name}") logger.error(str(e))
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=str(e))
@models_router.put( @models_router.put(
"/convert/{base_model}/{model_type}/{model_name}", "/convert/{base_model}/{model_type}/{model_name}",
@ -273,8 +252,8 @@ async def convert_model(
base_model = base_model, base_model = base_model,
model_type = model_type) model_type = model_type)
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except ModelNotFoundException as e:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found") raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found: {str(e)}")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
@ -364,8 +343,55 @@ async def merge_models(
model_type = ModelType.Main, model_type = ModelType.Main,
) )
response = parse_obj_as(ConvertModelResponse, model_raw) response = parse_obj_as(ConvertModelResponse, model_raw)
except KeyError: except ModelNotFoundException:
raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found") raise HTTPException(status_code=404, detail=f"One or more of the models '{model_names}' not found")
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
return response return response
# The rename operation is now supported by update_model and no longer needs to be
# a standalone route.
# @models_router.post(
# "/rename/{base_model}/{model_type}/{model_name}",
# operation_id="rename_model",
# responses= {
# 201: {"description" : "The model was renamed successfully"},
# 404: {"description" : "The model could not be found"},
# 409: {"description" : "There is already a model corresponding to the new name"},
# },
# status_code=201,
# response_model=ImportModelResponse
# )
# async def rename_model(
# base_model: BaseModelType = Path(description="Base model"),
# model_type: ModelType = Path(description="The type of model"),
# model_name: str = Path(description="current model name"),
# new_name: Optional[str] = Query(description="new model name", default=None),
# new_base: Optional[BaseModelType] = Query(description="new model base", default=None),
# ) -> ImportModelResponse:
# """ Rename a model"""
# logger = ApiDependencies.invoker.services.logger
# try:
# result = ApiDependencies.invoker.services.model_manager.rename_model(
# base_model = base_model,
# model_type = model_type,
# model_name = model_name,
# new_name = new_name,
# new_base = new_base,
# )
# logger.debug(result)
# logger.info(f'Successfully renamed {model_name}=>{new_name}')
# model_raw = ApiDependencies.invoker.services.model_manager.list_model(
# model_name=new_name or model_name,
# base_model=new_base or base_model,
# model_type=model_type
# )
# return parse_obj_as(ImportModelResponse, model_raw)
# except ModelNotFoundException as e:
# logger.error(str(e))
# raise HTTPException(status_code=404, detail=str(e))
# except ValueError as e:
# logger.error(str(e))
# raise HTTPException(status_code=409, detail=str(e))

View File

@ -18,6 +18,7 @@ from invokeai.backend.model_management import (
SchedulerPredictionType, SchedulerPredictionType,
ModelMerger, ModelMerger,
MergeInterpolationMethod, MergeInterpolationMethod,
ModelNotFoundException,
) )
from invokeai.backend.model_management.model_search import FindModels from invokeai.backend.model_management.model_search import FindModels
@ -145,7 +146,7 @@ class ModelManagerServiceBase(ABC):
) -> AddModelResult: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with a Update the named model with a dictionary of attributes. Will fail with a
KeyErrorException if the name does not already exist. ModelNotFoundException if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
@ -451,14 +452,14 @@ class ModelManagerService(ModelManagerServiceBase):
) -> AddModelResult: ) -> AddModelResult:
""" """
Update the named model with a dictionary of attributes. Will fail with a Update the named model with a dictionary of attributes. Will fail with a
KeyError exception if the name does not already exist. ModelNotFoundException exception if the name does not already exist.
On a successful update, the config will be changed in memory. Will fail On a successful update, the config will be changed in memory. Will fail
with an assertion error if provided attributes are incorrect or with an assertion error if provided attributes are incorrect or
the model name is missing. Call commit() to write changes to disk. the model name is missing. Call commit() to write changes to disk.
""" """
self.logger.debug(f'update model {model_name}') self.logger.debug(f'update model {model_name}')
if not self.model_exists(model_name, base_model, model_type): if not self.model_exists(model_name, base_model, model_type):
raise KeyError(f"Unknown model {model_name}") raise ModelNotFoundException(f"Unknown model {model_name}")
return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True) return self.add_model(model_name, base_model, model_type, model_attributes, clobber=True)
def del_model( def del_model(

View File

@ -3,6 +3,6 @@ Initialization file for invokeai.backend.model_management
""" """
from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType from .model_manager import ModelManager, ModelInfo, AddModelResult, SchedulerPredictionType
from .model_cache import ModelCache from .model_cache import ModelCache
from .models import BaseModelType, ModelType, SubModelType, ModelVariantType from .models import BaseModelType, ModelType, SubModelType, ModelVariantType, ModelNotFoundException
from .model_merge import ModelMerger, MergeInterpolationMethod from .model_merge import ModelMerger, MergeInterpolationMethod

View File

@ -552,7 +552,7 @@ class ModelManager(object):
model_config = self.models.get(model_key) model_config = self.models.get(model_key)
if not model_config: if not model_config:
self.logger.error(f'Unknown model {model_name}') self.logger.error(f'Unknown model {model_name}')
raise KeyError(f'Unknown model {model_name}') raise ModelNotFoundException(f'Unknown model {model_name}')
cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key) cur_model_name, cur_base_model, cur_model_type = self.parse_key(model_key)
if base_model is not None and cur_base_model != base_model: if base_model is not None and cur_base_model != base_model:
@ -596,7 +596,7 @@ class ModelManager(object):
model_cfg = self.models.pop(model_key, None) model_cfg = self.models.pop(model_key, None)
if model_cfg is None: if model_cfg is None:
raise KeyError(f"Unknown model {model_key}") raise ModelNotFoundException(f"Unknown model {model_key}")
# note: it not garantie to release memory(model can has other references) # note: it not garantie to release memory(model can has other references)
cache_ids = self.cache_keys.pop(model_key, []) cache_ids = self.cache_keys.pop(model_key, [])
@ -689,7 +689,7 @@ class ModelManager(object):
model_key = self.create_key(model_name, base_model, model_type) model_key = self.create_key(model_name, base_model, model_type)
model_cfg = self.models.get(model_key, None) model_cfg = self.models.get(model_key, None)
if not model_cfg: if not model_cfg:
raise KeyError(f"Unknown model: {model_key}") raise ModelNotFoundException(f"Unknown model: {model_key}")
old_path = self.app_config.root_path / model_cfg.path old_path = self.app_config.root_path / model_cfg.path
new_name = new_name or model_name new_name = new_name or model_name
@ -965,7 +965,7 @@ class ModelManager(object):
that model. that model.
May return the following exceptions: May return the following exceptions:
- KeyError - one or more of the items to import is not a valid path, repo_id or URL - ModelNotFoundException - one or more of the items to import is not a valid path, repo_id or URL
- ValueError - a corresponding model already exists - ValueError - a corresponding model already exists
''' '''
# avoid circular import here # avoid circular import here