refactor(ModelManager): factor out get_model_config

This commit is contained in:
Kevin Turner 2023-07-28 21:30:20 -07:00
parent dca685ac25
commit b163ae6a4d

View File

@ -460,7 +460,7 @@ class ModelManager(object):
if not self.model_exists(model_name, base_model, model_type, rescan=True): if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}") raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self.models[model_key] model_config = self._get_model_config(base_model, model_name, model_type)
model_path = self.app_config.root_path / model_config.path model_path = self.app_config.root_path / model_config.path
if not model_path.exists(): if not model_path.exists():
@ -518,6 +518,14 @@ class ModelManager(object):
_cache=self.cache, _cache=self.cache,
) )
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
model_key = self.create_key(model_name, base_model, model_type)
try:
model_config = self.models[model_key]
except KeyError:
raise ModelNotFoundException(f"Model not found - {model_key}")
return model_config
def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]:
model_class = MODEL_CLASSES[base_model][model_type] model_class = MODEL_CLASSES[base_model][model_type]
return model_class return model_class