From a51b165a40bfe3bd85ed6e1db6cdf8da7cc4734c Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Thu, 12 Oct 2023 13:07:09 -0400
Subject: [PATCH] clean up model downloader status locking to avoid race
 conditions

---
 invokeai/app/cli_app.py                        |  2 +-
 .../backend/model_manager/download/base.py     |  1 +
 .../model_manager/download/model_queue.py      | 22 +++---
 .../backend/model_manager/download/queue.py    | 71 ++++++++---------
 tests/AC_model_manager/test_model_download.py  | 79 ++++++++++---------
 5 files changed, 90 insertions(+), 85 deletions(-)

diff --git a/invokeai/app/cli_app.py b/invokeai/app/cli_app.py
index b411b63da3..5e203eefc0 100644
--- a/invokeai/app/cli_app.py
+++ b/invokeai/app/cli_app.py
@@ -40,6 +40,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from .cli.completer import set_autocompleter
     from .invocations.baseinvocation import BaseInvocation
     from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
+    from .services.download_manager import DownloadQueueService
     from .services.events import EventServiceBase
     from .services.graph import (
         Edge,
@@ -56,7 +57,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
     from .services.model_install_service import ModelInstallService
     from .services.model_loader_service import ModelLoadService
-    from .services.download_manager import DownloadQueueService
     from .services.model_record_service import ModelRecordServiceBase
     from .services.processor import DefaultInvocationProcessor
     from .services.sqlite import SqliteItemStorage
diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py
index ac00de766d..ee9cc4eeec 100644
--- a/invokeai/backend/model_manager/download/base.py
+++ b/invokeai/backend/model_manager/download/base.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
 """Abstract base class for a multithreaded model download queue."""
 
+import threading
 from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering
diff --git a/invokeai/backend/model_manager/download/model_queue.py b/invokeai/backend/model_manager/download/model_queue.py
index 270df4c798..2aded8d3ae 100644
--- a/invokeai/backend/model_manager/download/model_queue.py
+++ b/invokeai/backend/model_manager/download/model_queue.py
@@ -196,6 +196,8 @@ class ModelDownloadQueue(DownloadQueue):
         def subdownload_event(subjob: DownloadJobBase):
             assert isinstance(subjob, DownloadJobRemoteSource)
             assert isinstance(job, DownloadJobRemoteSource)
+            if job.status != DownloadJobStatus.RUNNING:  # do not update if we are cancelled or paused
+                return
             if subjob.status == DownloadJobStatus.RUNNING:
                 bytes_downloaded[subjob.id] = subjob.bytes
                 job.bytes = sum(bytes_downloaded.values())
@@ -214,13 +216,15 @@ class ModelDownloadQueue(DownloadQueue):
                 self._update_job_status(job, DownloadJobStatus.RUNNING)
                 return
 
-        subqueue = self.__class__(
-            event_handlers=[subdownload_event],
-            requests_session=self._requests,
-            quiet=True,
-        )
         assert isinstance(job, DownloadJobRepoID)
+        self._update_job_status(job, DownloadJobStatus.RUNNING)
+        self._lock.acquire()  # prevent status from being updated while we are setting up subqueue
         try:
+            job.subqueue = self.__class__(
+                event_handlers=[subdownload_event],
+                requests_session=self._requests,
+                quiet=True,
+            )
             repo_id = job.source
             variant = job.variant
             if not job.metadata:
@@ -235,7 +239,7 @@ class ModelDownloadQueue(DownloadQueue):
 
             for url, subdir, file, size in urls_to_download:
                 job.total_bytes += size
-                subqueue.create_download_job(
+                job.subqueue.create_download_job(
                     source=url,
                     destdir=job.destination / subdir,
                     filename=file,
@@ -249,11 +253,11 @@ class ModelDownloadQueue(DownloadQueue):
             self._update_job_status(job, DownloadJobStatus.ERROR)
             self._logger.error(job.error)
         finally:
-            job.subqueue = subqueue
-            job.subqueue.join()
+            self._lock.release()
+            if job.subqueue is not None:
+                job.subqueue.join()
             if job.status == DownloadJobStatus.RUNNING:
                 self._update_job_status(job, DownloadJobStatus.COMPLETED)
-            job.subqueue.release()  # get rid of the subqueue
 
     def _get_repo_info(
         self,
diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py
index 04264d6267..ca3a243b14 100644
--- a/invokeai/backend/model_manager/download/queue.py
+++ b/invokeai/backend/model_manager/download/queue.py
@@ -176,22 +176,6 @@ class DownloadQueue(DownloadQueueBase):
         except KeyError as excp:
             raise UnknownJobIDException("Unrecognized job") from excp
 
-    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
-        """
-        Cancel the indicated job.
-
-        If it is running it will be stopped.
-        job.status will be set to DownloadJobStatus.CANCELLED
-        """
-        with self._lock:
-            try:
-                assert isinstance(self._jobs[job.id], DownloadJobBase)
-                job.preserve_partial_downloads = preserve_partial
-                self._update_job_status(job, DownloadJobStatus.CANCELLED)
-                job.cleanup()
-            except (AssertionError, KeyError) as excp:
-                raise UnknownJobIDException("Unrecognized job") from excp
-
     def id_to_job(self, id: int) -> DownloadJobBase:
         """Translate a job ID into a DownloadJobBase object."""
         try:
@@ -224,6 +208,22 @@ class DownloadQueue(DownloadQueueBase):
         except (AssertionError, KeyError) as excp:
             raise UnknownJobIDException("Unrecognized job") from excp
 
+    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
+        """
+        Cancel the indicated job.
+
+        If it is running it will be stopped.
+        job.status will be set to DownloadJobStatus.CANCELLED
+        """
+        with self._lock:
+            try:
+                assert isinstance(self._jobs[job.id], DownloadJobBase)
+                job.preserve_partial_downloads = preserve_partial
+                self._update_job_status(job, DownloadJobStatus.CANCELLED)
+                job.cleanup()
+            except (AssertionError, KeyError) as excp:
+                raise UnknownJobIDException("Unrecognized job") from excp
+
     def start_all_jobs(self):
         """Start (enqueue) all jobs that are idle or paused."""
         with self._lock:
@@ -273,9 +273,7 @@ class DownloadQueue(DownloadQueueBase):
 
             if job == STOP_JOB:  # marker that queue is done
                 done = True
-            if (
-                job.status == DownloadJobStatus.ENQUEUED
-            ):  # Don't do anything for non-enqueued jobs (shouldn't happen)
+            if job.status == DownloadJobStatus.ENQUEUED:
                 if not self._quiet:
                     self._logger.info(f"{job.source}: Downloading to {job.destination}")
                 do_download = self.select_downloader(job)
@@ -393,17 +391,18 @@ class DownloadQueue(DownloadQueueBase):
 
     def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
         """Optionally change the job status and send an event indicating a change of state."""
-        if new_status:
-            job.status = new_status
+        with self._lock:
+            if new_status:
+                job.status = new_status
 
-        self._logger.debug(f"Status update for download job {job.id}: {job}")
-        if self._in_terminal_state(job) and not self._quiet:
-            self._logger.info(f"{job.source}: download job completed with status {job.status.value}")
+            if self._in_terminal_state(job) and not self._quiet:
+                self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
+
+            if new_status == DownloadJobStatus.RUNNING and not job.job_started:
+                job.job_started = time.time()
+            elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
+                job.job_ended = time.time()
 
-        if new_status == DownloadJobStatus.RUNNING and not job.job_started:
-            job.job_started = time.time()
-        elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
-            job.job_ended = time.time()
         if job.event_handlers:
             for handler in job.event_handlers:
                 try:
@@ -428,11 +427,11 @@ class DownloadQueue(DownloadQueueBase):
             self._update_job_status(job, DownloadJobStatus.ERROR)
 
     def _cleanup_cancelled_job(self, job: DownloadJobBase):
-        if job.preserve_partial_downloads:
-            return
-        self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
-        dest = Path(job.destination)
-        if dest.is_file():
-            dest.unlink()
-        elif dest.is_dir():
-            shutil.rmtree(dest.as_posix(), ignore_errors=True)
+        job.cleanup(job.preserve_partial_downloads)
+        if not job.preserve_partial_downloads:
+            self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
+            dest = Path(job.destination)
+            if dest.is_file():
+                dest.unlink()
+            elif dest.is_dir():
+                shutil.rmtree(dest.as_posix(), ignore_errors=True)
diff --git a/tests/AC_model_manager/test_model_download.py b/tests/AC_model_manager/test_model_download.py
index 8ea77d164c..f2803d145d 100644
--- a/tests/AC_model_manager/test_model_download.py
+++ b/tests/AC_model_manager/test_model_download.py
@@ -114,9 +114,9 @@ hf_sd2_paths = [
 ]
 for path in hf_sd2_paths:
     url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
-    path = Path(path)
-    filename = path.name
-    content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8")
+    path = Path(path).as_posix()
+    filename = Path(path).name
+    content = b"This is the content for path " + bytearray(path, "utf-8")
     session.mount(
         url,
         TestAdapter(
@@ -314,48 +314,49 @@ def test_pause_cancel_url():  # this one is tricky because of potential race con
     queue.prune_jobs()
     assert len(queue.list_jobs()) == 0
 
-    def test_pause_cancel_repo_id():  # this one is tricky because of potential race conditions
-        def event_handler(job: DownloadJobBase):
-            time.sleep(0.5)  # slow down the thread by blocking it just a bit at every step
+
+def test_pause_cancel_repo_id():  # this one is tricky because of potential race conditions
+    def event_handler(job: DownloadJobBase):
+        time.sleep(0.1)  # slow down the thread by blocking it just a bit at every step
 
-        if not INTERNET_AVAILABLE:
-            return
+    if not INTERNET_AVAILABLE:
+        return
 
-        repo_id = "stabilityai/stable-diffusion-2-1"
-        queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
+    repo_id = "stabilityai/stable-diffusion-2-1"
+    queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
 
-        with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
-            job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
-            job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
-            assert job1.status == "idle"
-            queue.start_job(job1)
-            time.sleep(0.1)  # wait for enqueueing
-            assert job1.status in ["enqueued", "running"]
+    with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
+        job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
+        job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
+        assert job1.status == "idle"
+        queue.start_job(job1)
+        time.sleep(0.1)  # wait for enqueueing
+        assert job1.status in ["enqueued", "running"]
 
-            # check pause and restart
-            queue.pause_job(job1)
-            time.sleep(0.1)  # wait to be paused
-            assert job1.status == "paused"
+        # check pause and restart
+        queue.pause_job(job1)
+        time.sleep(0.1)  # wait to be paused
+        assert job1.status == "paused"
 
-            queue.start_job(job1)
-            time.sleep(0.1)
-            assert job1.status == "running"
+        queue.start_job(job1)
+        time.sleep(0.5)
+        assert job1.status == "running"
 
-            # check cancel
-            queue.start_job(job2)
-            time.sleep(0.1)
-            assert job2.status == "running"
-            queue.cancel_job(job2)
-            time.sleep(0.1)
-            assert job2.status == "cancelled"
+        # check cancel
+        queue.start_job(job2)
+        time.sleep(0.1)
+        assert job2.status == "running"
+        queue.cancel_job(job2)
 
-            queue.join()
-            assert job1.status == "completed"
-            assert job2.status == "cancelled"
+        queue.join()
+        assert job1.status == "completed"
+        assert job2.status == "cancelled"
 
-            assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
-            assert not Path(
-                tmpdir2, "stable-diffusion-2-1", "model_index.json"
-            ).exists(), "cancelled file should be deleted"
+        assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
+        assert not Path(
+            tmpdir2, "stable-diffusion-2-1", "model_index.json"
+        ).exists(), "cancelled file should be deleted"
 
-        assert len(queue.list_jobs()) == 0
+    assert len(queue.list_jobs()) == 2
+    queue.prune_jobs()
+    assert len(queue.list_jobs()) == 0