From 911a24479b15e4858beda43817d91b22e6a32b51 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 16 May 2024 07:18:33 -0400 Subject: [PATCH] add tests for model install file size reporting --- .../app/services/model_install/model_install_default.py | 9 ++++++--- tests/app/services/model_install/test_model_install.py | 5 +++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index 2ad321c260..1d77b2c6e1 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -855,8 +855,8 @@ class ModelInstallService(ModelInstallServiceBase): str(job.source), local_path=job.local_path.as_posix(), parts=parts, - bytes=job.bytes, - total_bytes=job.total_bytes, + bytes=sum(x["bytes"] for x in parts), + total_bytes=sum(x["total_bytes"] for x in parts), id=job.id, ) @@ -875,7 +875,10 @@ class ModelInstallService(ModelInstallServiceBase): assert job.local_path is not None assert job.config_out is not None key = job.config_out.key - self._event_bus.emit_model_install_completed(str(job.source), key, id=job.id) + self._event_bus.emit_model_install_completed(source=str(job.source), + key=key, + id=job.id, + total_bytes=job.bytes) def _signal_job_errored(self, job: ModelInstallJob) -> None: self._logger.error(f"Model install error: {job.source}\n{job.error_type}: {job.error}") diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 31d09d1029..f73b827534 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -317,6 +317,11 @@ def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_con "model_install_completed", } + completed_events = [x for x in bus.events if x.event_name == "model_install_completed"] + downloading_events = [x for x in bus.events if x.event_name == "model_install_downloading"] + assert completed_events[0].payload["total_bytes"] == downloading_events[-1].payload["bytes"] + assert job.total_bytes == completed_events[0].payload["total_bytes"] + assert job.total_bytes == sum(x["total_bytes"] for x in downloading_events[-1].payload["parts"]) def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: source = URLModelSource(url=Url("https://test.com/missing_model.safetensors"))