refactor(ModelManager): factor out get_model_path

This commit is contained in:
Kevin Turner 2023-07-28 22:01:28 -07:00
parent b163ae6a4d
commit bc9a5038fd

View File

@ -451,37 +451,33 @@ class ModelManager(object):
:param model_name: symbolic name of the model in models.yaml
:param model_type: ModelType enum indicating the type of model to return
:param base_model: BaseModelType enum indicating the base model used by this model
:param submode_typel: an ModelType enum indicating the portion of
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
model_class = self._get_implementation(base_model, model_type)
model_key = self.create_key(model_name, base_model, model_type)
if not self.model_exists(model_name, base_model, model_type, rescan=True):
raise ModelNotFoundException(f"Model not found - {model_key}")
model_config = self._get_model_config(base_model, model_name, model_type)
model_path = self.app_config.root_path / model_config.path
model_path, is_submodel_override = self._get_model_path(model_config, submodel_type)
if is_submodel_override:
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
if not model_path.exists():
if model_class.save_to_config:
self.models[model_key].error = ModelError.NotFound
raise Exception(f'Files for model "{model_key}" not found')
raise Exception(f'Files for model "{model_key}" not found at {model_path}')
else:
self.models.pop(model_key, None)
raise ModelNotFoundException(f"Model not found - {model_key}")
# vae/movq override
# TODO:
if submodel_type is not None and hasattr(model_config, submodel_type):
override_path = getattr(model_config, submodel_type)
if override_path:
model_path = self.app_config.root_path / override_path
model_type = submodel_type
submodel_type = None
model_class = self._get_implementation(base_model, model_type)
# TODO: path
# TODO: is it accurate to use path as id
dst_convert_path = self._get_model_cache_path(model_path)
@ -518,6 +514,20 @@ class ModelManager(object):
_cache=self.cache,
)
def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool):
model_path = model_config.path
is_submodel_override = False
# Does the config explicitly override the submodel?
if submodel_type is not None and hasattr(model_config, submodel_type):
submodel_path = getattr(model_config, submodel_type)
if submodel_path is not None:
model_path = getattr(model_config, submodel_type)
is_submodel_override = True
model_path = self.app_config.root_path / model_path
return model_path, is_submodel_override
def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase:
model_key = self.create_key(model_name, base_model, model_type)
try: