From ead1748c544696faf49c8e0d73366feaf34463a9 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Tue, 28 May 2024 19:30:42 -0400 Subject: [PATCH] issue a download progress event when install download starts --- .../services/model_install/model_install_default.py | 8 +++++--- .../app/services/model_install/test_model_install.py | 11 ++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index cd4bce4108..3b8a408e97 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -815,10 +815,13 @@ class ModelInstallService(ModelInstallServiceBase): if install_job := self._download_cache.get(download_job.id, None): install_job.status = InstallStatus.DOWNLOADING - assert download_job.download_path if install_job.local_path == install_job._install_tmpdir: # first time + assert download_job.download_path install_job.local_path = download_job.download_path - install_job.total_bytes = download_job.total_bytes + install_job.download_parts = download_job.download_parts + install_job.bytes = sum(x.bytes for x in download_job.download_parts) + install_job.total_bytes = download_job.total_bytes + self._signal_job_downloading(install_job) def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: @@ -829,7 +832,6 @@ class ModelInstallService(ModelInstallServiceBase): # update sizes install_job.bytes = sum(x.bytes for x in download_job.download_parts) install_job.total_bytes = sum(x.total_bytes for x in download_job.download_parts) - install_job.download_parts = download_job.download_parts self._signal_job_downloading(install_job) def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index 5c9f908ccc..9602a79a27 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -251,11 +251,12 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: model_record = store.get_model(key) assert (mm2_app_config.models_path / model_record.path).exists() - assert len(bus.events) == 4 - assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) - assert isinstance(bus.events[1], ModelInstallDownloadsCompleteEvent) - assert isinstance(bus.events[2], ModelInstallStartedEvent) - assert isinstance(bus.events[3], ModelInstallCompleteEvent) + assert len(bus.events) == 5 + assert isinstance(bus.events[0], ModelInstallDownloadProgressEvent) # download starts + assert isinstance(bus.events[1], ModelInstallDownloadProgressEvent) # download progresses + assert isinstance(bus.events[2], ModelInstallDownloadsCompleteEvent) # download completed + assert isinstance(bus.events[3], ModelInstallStartedEvent) # install started + assert isinstance(bus.events[4], ModelInstallCompleteEvent) # install completed @pytest.mark.timeout(timeout=10, method="thread")