prevent two models from being marked default in models.yaml

This commit is contained in:
Lincoln Stein 2022-11-11 04:41:02 +00:00
parent 8dc7f119e5
commit af040e97af
2 changed files with 7 additions and 4 deletions

View File

@ -109,10 +109,13 @@ class ModelCache(object):
Set the default model. The change will not take Set the default model. The change will not take
effect until you call model_cache.commit() effect until you call model_cache.commit()
''' '''
print(f'DEBUG: before set_default_model()\n{OmegaConf.to_yaml(self.config)}')
assert model_name in self.models,f"unknown model '{model_name}'" assert model_name in self.models,f"unknown model '{model_name}'"
for model in self.models: config = self.config
self.models[model].pop('default',None) for model in config:
self.models[model_name]['default'] = True config[model].pop('default',None)
config[model_name]['default'] = True
print(f'DEBUG: after set_default_model():\n{OmegaConf.to_yaml(self.config)}')
def list_models(self) -> dict: def list_models(self) -> dict:
''' '''

View File

@ -584,7 +584,7 @@ def write_config_file(conf_path, gen, model_name, new_config, clobber=False, mak
try: try:
print('>> Verifying that new model loads...') print('>> Verifying that new model loads...')
yaml_str = gen.model_cache.add_model(model_name, new_config, clobber) gen.model_cache.add_model(model_name, new_config, clobber)
assert gen.set_model(model_name) is not None, 'model failed to load' assert gen.set_model(model_name) is not None, 'model failed to load'
except AssertionError as e: except AssertionError as e:
print(f'** aborting **') print(f'** aborting **')