Simplify logic for determining model type in probe

This commit is contained in:
Brandon Rising 2024-03-17 22:42:21 -04:00 committed by Brandon
parent 39f62ac63c
commit 28bfc1c935
2 changed files with 17 additions and 6 deletions

View File

@ -132,8 +132,7 @@ class ModelProbe(object):
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
model_info = None
model_type = fields["type"] if "type" in fields else None
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
if not model_type:
if format_type is ModelFormat.Diffusers:
model_type = cls.get_model_type_from_folder(model_path)

View File

@ -286,15 +286,27 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
)
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
"""Test whether or not type is respected on configs when passed to heuristic import."""
config: Dict[str, Any] = {
assert "name" in model_params and "type" in model_params
config1: Dict[str, Any] = {
"name": f"{model_params['name']}_1",
"type": model_params["type"],
}
config2: Dict[str, Any] = {
"name": f"{model_params['name']}_2",
"type": ModelType(model_params["type"]),
}
try:
assert "repo_id" in model_params
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
while not install_job.in_terminal_state:
while not install_job1.in_terminal_state:
sleep(0.01)
assert install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
while not install_job2.in_terminal_state:
sleep(0.01)
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
except InvalidModelConfigException:
assert model_params["type"] != "embedding"