Simplify logic for determining model type in probe

This commit is contained in:
Brandon Rising 2024-03-17 22:42:21 -04:00 committed by Brandon
parent 39f62ac63c
commit 28bfc1c935
2 changed files with 17 additions and 6 deletions

View File

@ -132,8 +132,7 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None model_info = None
model_type = fields["type"] if "type" in fields else None model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
if not model_type: if not model_type:
if format_type is ModelFormat.Diffusers: if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path) model_type = cls.get_model_type_from_folder(model_path)

View File

@ -286,15 +286,27 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
) )
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]): def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import.""" """Test whether or not type is respected on configs when passed to heuristic import."""
config: Dict[str, Any] = { assert "name" in model_params and "type" in model_params
config1: Dict[str, Any] = {
"name": f"{model_params['name']}_1",
"type": model_params["type"], "type": model_params["type"],
} }
config2: Dict[str, Any] = {
"name": f"{model_params['name']}_2",
"type": ModelType(model_params["type"]),
}
try: try:
assert "repo_id" in model_params assert "repo_id" in model_params
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config) install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
while not install_job.in_terminal_state: while not install_job1.in_terminal_state:
sleep(0.01) sleep(0.01)
assert install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
while not install_job2.in_terminal_state:
sleep(0.01)
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
except InvalidModelConfigException: except InvalidModelConfigException:
assert model_params["type"] != "embedding" assert model_params["type"] != "embedding"