From 813a086cfe9fd73f69e4b62eb471d6bec29a5dfb Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 29 Feb 2024 15:08:10 -0500
Subject: [PATCH] fix race condition between downloading last file and starting install

---
 .../app/services/download/download_default.py | 21 ++-------------------
 .../model_install/model_install_base.py       |  6 ++++++
 .../model_install/model_install_default.py    | 19 ++++---------------
 pyproject.toml                                |  4 ++--
 .../services/download/test_download_queue.py  |  1 +
 .../model_install/test_model_install.py       |  1 +
 .../model_manager/model_manager_fixtures.py   |  2 ++
 7 files changed, 18 insertions(+), 36 deletions(-)

diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py
index c336ff4c8c..843351a259 100644
--- a/invokeai/app/services/download/download_default.py
+++ b/invokeai/app/services/download/download_default.py
@@ -221,29 +221,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
             except Empty:
                 continue
             try:
-                print(f'DEBUG: job [{job.id}] started', flush=True)
                 job.job_started = get_iso_timestamp()
                 self._do_download(job)
-                print(f'DEBUG: job [{job.id}] download completed', flush=True)
                 self._signal_job_complete(job)
-                print(f'DEBUG: job [{job.id}] signaled completion', flush=True)
             except (OSError, HTTPError) as excp:
                 job.error_type = excp.__class__.__name__ + f"({str(excp)})"
                 job.error = traceback.format_exc()
                 self._signal_job_error(job, excp)
-                print(f'DEBUG: job [{job.id}] signaled error', flush=True)
             except DownloadJobCancelledException:
                 self._signal_job_cancelled(job)
                 self._cleanup_cancelled_job(job)
-                print(f'DEBUG: job [{job.id}] signaled cancelled', flush=True)
-
+
             finally:
-                print(f'DEBUG: job [{job.id}] signalling completion', flush=True)
                 job.job_ended = get_iso_timestamp()
                 self._job_completed_event.set()  # signal a change to terminal state
-                print(f'DEBUG: job [{job.id}] set job completion event', flush=True)
                 self._queue.task_done()
-                print(f'DEBUG: job [{job.id}] done', flush=True)
         self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
 
     def _do_download(self, job: DownloadJob) -> None:
@@ -251,8 +243,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
         url = job.source
         header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
         open_mode = "wb"
-        print(f'DEBUG: In _do_download [0]', flush=True)
-
+
         # Make a streaming request. This will retrieve headers including
         # content-length and content-disposition, but not fetch any content itself
         resp = self._requests.get(str(url), headers=header, stream=True)
@@ -263,8 +254,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
 
         content_length = int(resp.headers.get("content-length", 0))
         job.total_bytes = content_length
-        print(f'DEBUG: In _do_download [1]')
-
         if job.dest.is_dir():
             file_name = os.path.basename(str(url.path))  # default is to use the last bit of the URL
 
@@ -292,7 +281,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
         # signal caller that the download is starting. At this point, key fields such as
         # download_path and total_bytes will be populated. We call it here because the might
         # discover that the local file is already complete and generate a COMPLETED status.
-        print(f'DEBUG: In _do_download [2]', flush=True)
         self._signal_job_started(job)
 
         # "range not satisfiable" - local file is at least as large as the remote file
@@ -308,15 +296,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
         elif resp.status_code != 200:
             raise HTTPError(resp.reason)
 
-        print(f'DEBUG: In _do_download [3]', flush=True)
-
         self._logger.debug(f"{job.source}: Downloading {job.download_path}")
         report_delta = job.total_bytes / 100  # report every 1% change
         last_report_bytes = 0
 
         # DOWNLOAD LOOP
         with open(in_progress_path, open_mode) as file:
-            print(f'DEBUG: In _do_download loop [4]', flush=True)
             for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                 if job.cancelled:
                     raise DownloadJobCancelledException("Job was cancelled at caller's request")
@@ -329,8 +314,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
         # if we get here we are done and can rename the file to the original dest
         self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
         in_progress_path.rename(job.download_path)
-        print(f'DEBUG: In _do_download [5]', flush=True)
-
 
     def _validate_filename(self, directory: str, filename: str) -> bool:
         pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260  # hardcoded for windows
diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py
index 737f62a064..4f2cdaed8e 100644
--- a/invokeai/app/services/model_install/model_install_base.py
+++ b/invokeai/app/services/model_install/model_install_base.py
@@ -28,6 +28,7 @@ class InstallStatus(str, Enum):
 
     WAITING = "waiting"  # waiting to be dequeued
     DOWNLOADING = "downloading"  # downloading of model files in process
+    DOWNLOADS_DONE = "downloads_done"  # downloading done, waiting to run
     RUNNING = "running"  # being processed
     COMPLETED = "completed"  # finished running
     ERROR = "error"  # terminated with an error message
@@ -229,6 +230,11 @@ class ModelInstallJob(BaseModel):
         """Return true if job is downloading."""
         return self.status == InstallStatus.DOWNLOADING
 
+    @property
+    def downloads_done(self) -> bool:
+        """Return true if job's downloads are done."""
+        return self.status == InstallStatus.DOWNLOADS_DONE
+
     @property
     def running(self) -> bool:
         """Return true if job is running."""
diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py
index fe8124923f..93287a40c6 100644
--- a/invokeai/app/services/model_install/model_install_default.py
+++ b/invokeai/app/services/model_install/model_install_default.py
@@ -28,7 +28,6 @@ from invokeai.backend.model_manager.config import (
     ModelRepoVariant,
     ModelType,
 )
-from invokeai.backend.model_manager.hash import FastModelHash
 from invokeai.backend.model_manager.metadata import (
     AnyModelRepoMetadata,
     CivitaiMetadataFetch,
@@ -153,7 +152,6 @@ class ModelInstallService(ModelInstallServiceBase):
             config["source"] = model_path.resolve().as_posix()
 
         info: AnyModelConfig = self._probe_model(Path(model_path), config)
-        old_hash = info.current_hash
 
         if preferred_name := config.get("name"):
             preferred_name = Path(preferred_name).with_suffix(model_path.suffix)
@@ -167,8 +165,6 @@ class ModelInstallService(ModelInstallServiceBase):
             raise DuplicateModelException(
                 f"A model named {model_path.name} is already installed at {dest_path.as_posix()}"
             ) from excp
-        new_hash = FastModelHash.hash(new_path)
-        assert new_hash == old_hash, f"{model_path}: Model hash changed during installation, possibly corrupted."
 
         return self._register(
             new_path,
@@ -370,7 +366,7 @@ class ModelInstallService(ModelInstallServiceBase):
                     self._signal_job_errored(job)
 
                 elif (
-                    job.waiting or job.downloading
+                    job.waiting or job.downloads_done
                 ):  # local jobs will be in waiting state, remote jobs will be downloading state
                     job.total_bytes = self._stat_size(job.local_path)
                     job.bytes = job.total_bytes
@@ -448,7 +444,7 @@ class ModelInstallService(ModelInstallServiceBase):
             installed.update(self.scan_directory(models_dir))
         self._logger.info(f"{len(installed)} new models registered; {len(defunct_models)} unregistered")
 
-    def _sync_model_path(self, key: str, ignore_hash_change: bool = False) -> AnyModelConfig:
+    def _sync_model_path(self, key: str) -> AnyModelConfig:
         """
         Move model into the location indicated by its basetype, type and name.
 
@@ -469,14 +465,7 @@ class ModelInstallService(ModelInstallServiceBase):
         new_path = models_dir / model.base.value / model.type.value / model.name
         self._logger.info(f"Moving {model.name} to {new_path}.")
         new_path = self._move_model(old_path, new_path)
-        new_hash = FastModelHash.hash(new_path)
         model.path = new_path.relative_to(models_dir).as_posix()
-        if model.current_hash != new_hash:
-            assert (
-                ignore_hash_change
-            ), f"{model.name}: Model hash changed during installation, model is possibly corrupted"
-            model.current_hash = new_hash
-            self._logger.info(f"Model has new hash {model.current_hash}, but will continue to be identified by {key}")
         self.record_store.update_model(key, model)
         return model
 
@@ -749,8 +738,8 @@ class ModelInstallService(ModelInstallServiceBase):
             self._download_cache.pop(download_job.source, None)
 
             # are there any more active jobs left in this task?
-            if all(x.complete for x in install_job.download_parts):
-                # now enqueue job for actual installation into the models directory
+            if install_job.downloading and all(x.complete for x in install_job.download_parts):
+                install_job.status = InstallStatus.DOWNLOADS_DONE
                 self._install_queue.put(install_job)
 
             # Let other threads know that the number of downloads has changed
diff --git a/pyproject.toml b/pyproject.toml
index 661adfe4c4..26db5a63c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -187,10 +187,10 @@ version = { attr = "invokeai.version.__version__" }
 
 #=== Begin: PyTest and Coverage
 [tool.pytest.ini_options]
-addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers -m \"not slow\""
+addopts = "--cov-report term --cov-report html --cov-report xml --strict-markers --timeout 60 -m \"not slow\""
 markers = [
     "slow: Marks tests as slow. Disabled by default. To run all tests, use -m \"\". To run only slow tests, use -m \"slow\".",
-    "timeout: 60"
+    "timeout: Marks the timeout override."
 ]
 [tool.coverage.run]
 branch = true
diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py
index 77d77836ec..543703d713 100644
--- a/tests/app/services/download/test_download_queue.py
+++ b/tests/app/services/download/test_download_queue.py
@@ -166,6 +166,7 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
     # assert re.search("division by zero", captured.err)
     queue.stop()
 
+@pytest.mark.timeout(timeout=15, method="thread")
 def test_cancel(tmp_path: Path, session: Session) -> None:
     event_bus = TestEventService()
 
diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py
index c0a6e7b2b6..c1e588089b 100644
--- a/tests/app/services/model_install/test_model_install.py
+++ b/tests/app/services/model_install/test_model_install.py
@@ -220,6 +220,7 @@ def test_simple_download(mm2_installer: ModelInstallServiceBase, mm2_app_config:
     event_names = [x.event_name for x in bus.events]
     assert event_names == ["model_install_downloading", "model_install_running", "model_install_completed"]
 
+@pytest.mark.timeout(timeout=20, method="thread")
 def test_huggingface_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
     source = URLModelSource(url=Url("https://huggingface.co/stabilityai/sdxl-turbo"))
 
diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py
index df54e2f926..fce72cb04d 100644
--- a/tests/backend/model_manager/model_manager_fixtures.py
+++ b/tests/backend/model_manager/model_manager_fixtures.py
@@ -2,6 +2,7 @@
 
 import os
 import shutil
+import time
 from pathlib import Path
 from typing import Any, Dict, List
 
@@ -149,6 +150,7 @@ def mm2_installer(
 
     def stop_installer() -> None:
         installer.stop()
+        time.sleep(0.1)  # avoid error message from the logger when it is closed before thread prints final message
 
     request.addfinalizer(stop_installer)
     return installer
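
For reviewers, here is a minimal self-contained sketch of the synchronization pattern the patch adopts. It is an illustration only, not the InvokeAI implementation: the names InstallJob, part_finished, and install_queue are invented for the example. The idea is that the transition from DOWNLOADING to DOWNLOADS_DONE happens exactly once, under a lock, inside the callback that observes the final completed download part, so the install queue cannot be handed the same job twice when the last two files finish almost simultaneously.

import threading
from enum import Enum
from queue import Queue


class InstallStatus(str, Enum):
    DOWNLOADING = "downloading"
    DOWNLOADS_DONE = "downloads_done"
    RUNNING = "running"


class InstallJob:
    """Toy stand-in for a multi-part install job (illustrative only)."""

    def __init__(self, parts: int) -> None:
        self.status = InstallStatus.DOWNLOADING
        self._remaining = parts
        self._lock = threading.Lock()

    def part_finished(self) -> bool:
        """Record one finished download part; return True only for the part that completes the set."""
        with self._lock:
            self._remaining -= 1
            if self._remaining == 0 and self.status is InstallStatus.DOWNLOADING:
                # Guarding on the current status makes the transition one-shot:
                # a late or duplicate callback can no longer re-enqueue the job.
                self.status = InstallStatus.DOWNLOADS_DONE
                return True
            return False


install_queue: "Queue[InstallJob]" = Queue()


def download_complete_callback(job: InstallJob) -> None:
    # Only the callback that performs the DOWNLOADING -> DOWNLOADS_DONE
    # transition hands the job to the installer thread.
    if job.part_finished():
        install_queue.put(job)

In the patch itself the same effect comes from guarding the enqueue with install_job.downloading, setting InstallStatus.DOWNLOADS_DONE, and only then calling self._install_queue.put(install_job).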