From e70bedba7d16b4c928220286ec09d38f058d742c Mon Sep 17 00:00:00 2001 From: Kevin Turner <83819+keturn@users.noreply.github.com> Date: Fri, 28 Jul 2023 21:03:27 -0700 Subject: [PATCH] refactor(ModelManager): factor out _get_implementation method --- .../backend/model_management/model_manager.py | 20 +++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_management/model_manager.py b/invokeai/backend/model_management/model_manager.py index 2a82061a97..fbabd2fece 100644 --- a/invokeai/backend/model_management/model_manager.py +++ b/invokeai/backend/model_management/model_manager.py @@ -258,7 +258,7 @@ from .models import ( ModelConfigBase, ModelNotFoundException, InvalidModelException, - DuplicateModelException, + DuplicateModelException, ModelBase, ) # We are only starting to number the config file with release 3. @@ -361,7 +361,7 @@ class ModelManager(object): if model_key.startswith("_"): continue model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # alias for config file model_config["model_format"] = model_config.pop("format") self.models[model_key] = model_class.create_config(**model_config) @@ -446,7 +446,7 @@ class ModelManager(object): :param submode_typel: an ModelType enum indicating the portion of the model to retrieve (e.g. ModelType.Vae) """ - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_key = self.create_key(model_name, base_model, model_type) # if model not found try to find it (maybe file just pasted) @@ -475,7 +475,7 @@ class ModelManager(object): model_path = self.app_config.root_path / override_path model_type = submodel_type submodel_type = None - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) # TODO: path # TODO: is it accurate to use path as id @@ -513,6 +513,10 @@ class ModelManager(object): _cache=self.cache, ) + def _get_implementation(self, base_model: BaseModelType, model_type: ModelType) -> type[ModelBase]: + model_class = MODEL_CLASSES[base_model][model_type] + return model_class + def model_info( self, model_name: str, @@ -659,7 +663,7 @@ class ModelManager(object): if Path(path).is_relative_to(self.app_config.root_path): model_attributes["path"] = str(Path(path).relative_to(self.app_config.root_path)) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) model_config = model_class.create_config(**model_attributes) model_key = self.create_key(model_name, base_model, model_type) @@ -837,7 +841,7 @@ class ModelManager(object): for model_key, model_config in self.models.items(): model_name, base_model, model_type = self.parse_key(model_key) - model_class = MODEL_CLASSES[base_model][model_type] + model_class = self._get_implementation(base_model, model_type) if model_class.save_to_config: # TODO: or exclude_unset better fits here? data_to_save[model_key] = model_config.dict(exclude_defaults=True, exclude={"error"}) @@ -888,7 +892,7 @@ class ModelManager(object): model_name, cur_base_model, cur_model_type = self.parse_key(model_key) model_path = self.app_config.root_path.absolute() / model_config.path if not model_path.exists(): - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) if model_class.save_to_config: model_config.error = ModelError.NotFound self.models.pop(model_key, None) @@ -904,7 +908,7 @@ class ModelManager(object): for cur_model_type in ModelType: if model_type is not None and cur_model_type != model_type: continue - model_class = MODEL_CLASSES[cur_base_model][cur_model_type] + model_class = self._get_implementation(cur_base_model, cur_model_type) models_dir = self.app_config.models_path / cur_base_model.value / cur_model_type.value if not models_dir.exists():