mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
more refactoring; HF subfolders not working
This commit is contained in:
@ -222,7 +222,7 @@ def test_delete_register(
|
||||
store.get_model(key)
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://www.test.foo/download/test_embedding.safetensors"))
|
||||
|
||||
@ -253,7 +253,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
|
||||
|
||||
@ -285,9 +285,8 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
# TODO: Test subfolder download
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
@ -323,6 +322,7 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
assert job.total_bytes == completed_events[0].payload["total_bytes"]
|
||||
assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"])
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))
|
||||
job = mm2_installer.import_model(source)
|
||||
@ -371,7 +371,7 @@ def test_other_error_during_install(
|
||||
},
|
||||
],
|
||||
)
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, model_params: Dict[str, str]):
|
||||
"""Test whether or not type is respected on configs when passed to heuristic import."""
|
||||
assert "name" in model_params and "type" in model_params
|
||||
@ -387,7 +387,7 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
}
|
||||
assert "repo_id" in model_params
|
||||
install_job1 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config1)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job1, timeout=10)
|
||||
if model_params["type"] != "embedding":
|
||||
assert install_job1.errored
|
||||
assert install_job1.error_type == "InvalidModelConfigException"
|
||||
@ -396,6 +396,6 @@ def test_heuristic_import_with_type(mm2_installer: ModelInstallServiceBase, mode
|
||||
assert install_job1.config_out if model_params["type"] == "embedding" else not install_job1.config_out
|
||||
|
||||
install_job2 = mm2_installer.heuristic_import(source=model_params["repo_id"], config=config2)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=20)
|
||||
mm2_installer.wait_for_job(install_job2, timeout=10)
|
||||
assert install_job2.complete
|
||||
assert install_job2.config_out if model_params["type"] == "embedding" else not install_job2.config_out
|
||||
|
Reference in New Issue
Block a user