mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Simplify logic for determining model type in probe
This commit is contained in:
parent
39f62ac63c
commit
28bfc1c935
@ -132,8 +132,7 @@ class ModelProbe(object):
|
|||||||
|
|
||||||
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
format_type = ModelFormat.Diffusers if model_path.is_dir() else ModelFormat.Checkpoint
|
||||||
model_info = None
|
model_info = None
|
||||||
model_type = fields["type"] if "type" in fields else None
|
model_type = ModelType(fields["type"]) if "type" in fields and fields["type"] else None
|
||||||
model_type = ModelType(model_type) if isinstance(model_type, str) else model_type
|
|
||||||
if not model_type:
|
if not model_type:
|
||||||
if format_type is ModelFormat.Diffusers:
|
if format_type is ModelFormat.Diffusers:
|
||||||
model_type = cls.get_model_type_from_folder(model_path)
|
model_type = cls.get_model_type_from_folder(model_path)
|
||||||
|
@ -286,15 +286,27 @@ def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: In
|
|||||||
)
|
)
|
||||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||||
config: Dict[str, Any] = {
|
assert "name" in model_params and "type" in model_params
|
||||||
|
config1: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_1",
|
||||||
"type": model_params["type"],
|
"type": model_params["type"],
|
||||||
}
|
}
|
||||||
|
config2: Dict[str, Any] = {
|
||||||
|
"name": f"{model_params['name']}_2",
|
||||||
|
"type": ModelType(model_params["type"]),
|
||||||
|
}
|
||||||
try:
|
try:
|
||||||
assert "repo_id" in model_params
|
assert "repo_id" in model_params
|
||||||
install_job = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config)
|
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||||
|
|
||||||
while not install_job.in_terminal_state:
|
while not install_job1.in_terminal_state:
|
||||||
sleep(0.01)
|
sleep(0.01)
|
||||||
assert install_job.config_out if model_params["type"] == "embedding" else not install_job.config_out
|
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||||
|
|
||||||
|
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||||
|
|
||||||
|
while not install_job2.in_terminal_state:
|
||||||
|
sleep(0.01)
|
||||||
|
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||||
except InvalidModelConfigException:
|
except InvalidModelConfigException:
|
||||||
assert model_params["type"] != "embedding"
|
assert model_params["type"] != "embedding"
|
||||||
|
Loading…
Reference in New Issue
Block a user