diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 51053f92cc..6c79b07959 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -451,37 +451,33 @@ class ModelManager(object): :param model_name: symbolic name of the model in models.yaml :param model_type: ModelType enum indicating the type of model to return :param base_model: BaseModelType enum indicating the base model used by this model - :param submode_typel: an ModelType enum indicating the portion of + :param submodel_type: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = self._get_implementation(base_model, model_type) model_key = self.create_key(model_name, base_model, model_type) if not self.model_exists(model_name, base_model, model_type, rescan=True): raise ModelNotFoundException(f"Model not found - {model_key}") model_config = self._get_model_config(base_model, model_name, model_type) - model_path = self.app_config.root_path / model_config.path + + model_path, is_submodel_override = self._get_model_path(model_config, submodel_type) + + if is_submodel_override: + model_type = submodel_type + submodel_type = None + + model_class = self._get_implementation(base_model, model_type) if not model_path.exists(): if model_class.save_to_config: self.models[model_key].error = ModelError.NotFound - raise Exception(f'Files for model "{model_key}" not found') + raise Exception(f'Files for model "{model_key}" not found at {model_path}') else: self.models.pop(model_key, None) raise ModelNotFoundException(f"Model not found - {model_key}") - # vae/movq override - # TODO: - if submodel_type is not None and hasattr(model_config, submodel_type): - override_path = getattr(model_config, submodel_type) - if override_path: - model_path = self.app_config.root_path / override_path - model_type = submodel_type - submodel_type = None - model_class = self._get_implementation(base_model, model_type) - # TODO: path # TODO: is it accurate to use path as id dst_convert_path = self._get_model_cache_path(model_path) @@ -518,6 +514,20 @@ class ModelManager(object): _cache=self.cache, ) + def _get_model_path(self, model_config: ModelConfigBase, submodel_type: SubModelType = None) -> (Path, bool): + model_path = model_config.path + is_submodel_override = False + + # Does the config explicitly override the submodel? + if submodel_type is not None and hasattr(model_config, submodel_type): + submodel_path = getattr(model_config, submodel_type) + if submodel_path is not None: + model_path = getattr(model_config, submodel_type) + is_submodel_override = True + + model_path = self.app_config.root_path / model_path + return model_path, is_submodel_override + def _get_model_config(self, base_model, model_name, model_type) -> ModelConfigBase: model_key = self.create_key(model_name, base_model, model_type) try: