From 28bfc1c9351c0a3d48215b9887f11417f1e254f4 Mon Sep 17 00:00:00 2001 From: Brandon Rising Date: Sun, 17 Mar 2024 22:42:21 -0400 Subject: [PATCH] Simplify logic for determining model type in probe --- invokeai/backend/model_manager/probe.py | 3 +-- .../model_install/test_model_install.py | 20 +++++++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py index 2e433049ff..e814644a45 100644 --- a/invokeai/backend/model_manager/probe.py +++ b/invokeai/backend/model_manager/probe.py @@ -132,8 +132,7 @@ class ModelProbe(object): format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint model_info = None - model_type = fields["type"] if "type" in fields else None - model_type = ModelType(model_type) if isinstance(model_type, str) else model_type + model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None if not model_type: if format_type is ModelFormat.Diffusers: model_type = cls.get_model_type_from_folder(model_path) diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index bb507fb12e..7a3e705f16 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -286,15 +286,27 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In ) def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): """Test whether or not type is respected on configs when passed to heuristic import.""" - config: Dict[str, Any] = { + assert "name" in model_params and "type" in model_params + config1: Dict[str, Any] = { + "name": f"{model_params['name']}_1", "type": model_params["type"], } + config2: Dict[str, Any] = { + "name": f"{model_params['name']}_2", + "type": ModelType(model_params["type"]), + } try: assert "repo_id" in model_params - install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config) + install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1) - while not install_job.in_terminal_state: + while not install_job1.in_terminal_state: sleep(0.01) - assert install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out + assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out + + install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2) + + while not install_job2.in_terminal_state: + sleep(0.01) + assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out except InvalidModelConfigException: assert model_params["type"] != "embedding"