keep model path consistent with model manager key in model update api

This commit is contained in:
Lincoln Stein 2023-07-17 10:00:28 -04:00
parent 0ea8d3c30c
commit 08854b6d68

View File

@ -63,20 +63,35 @@ async def update_model(
) -> UpdateModelResponse: ) -> UpdateModelResponse:
""" Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """ """ Update model contents with a new config. If the model name or base fields are changed, then the model is renamed. """
logger = ApiDependencies.invoker.services.logger logger = ApiDependencies.invoker.services.logger
try: try:
previous_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
# rename operation requested # rename operation requested
if info.model_name != model_name or info.base_model != base_model: if info.model_name != model_name or info.base_model != base_model:
result = ApiDependencies.invoker.services.model_manager.rename_model( ApiDependencies.invoker.services.model_manager.rename_model(
base_model = base_model, base_model = base_model,
model_type = model_type, model_type = model_type,
model_name = model_name, model_name = model_name,
new_name = info.model_name, new_name = info.model_name,
new_base = info.base_model, new_base = info.base_model,
) )
logger.debug(f'renaming result = {result}')
logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}') logger.info(f'Successfully renamed {base_model}/{model_name}=>{info.base_model}/{info.model_name}')
# update information to support an update of attributes
model_name = info.model_name model_name = info.model_name
base_model = info.base_model base_model = info.base_model
new_info = ApiDependencies.invoker.services.model_manager.list_model(
model_name=model_name,
base_model=base_model,
model_type=model_type,
)
if new_info.get('path') != previous_info.get('path'): # model manager moved model path during rename - don't overwrite it
info.path = new_info.get('path')
ApiDependencies.invoker.services.model_manager.update_model( ApiDependencies.invoker.services.model_manager.update_model(
model_name=model_name, model_name=model_name,