clean up model downloader status locking to avoid race conditions

This commit is contained in:
Lincoln Stein
2023-10-12 13:07:09 -04:00
parent 5f80d4dd07
commit a51b165a40
5 changed files with 90 additions and 85 deletions

View File

@@ -40,6 +40,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .cli.completer import set_autocompleter from .cli.completer import set_autocompleter
from .invocations.baseinvocation import BaseInvocation from .invocations.baseinvocation import BaseInvocation
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
from .services.download_manager import DownloadQueueService
from .services.events import EventServiceBase from .services.events import EventServiceBase
from .services.graph import ( from .services.graph import (
Edge, Edge,
@@ -56,7 +57,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
from .services.model_install_service import ModelInstallService from .services.model_install_service import ModelInstallService
from .services.model_loader_service import ModelLoadService from .services.model_loader_service import ModelLoadService
from .services.download_manager import DownloadQueueService
from .services.model_record_service import ModelRecordServiceBase from .services.model_record_service import ModelRecordServiceBase
from .services.processor import DefaultInvocationProcessor from .services.processor import DefaultInvocationProcessor
from .services.sqlite import SqliteItemStorage from .services.sqlite import SqliteItemStorage

View File

@@ -1,6 +1,7 @@
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
"""Abstract base class for a multithreaded model download queue.""" """Abstract base class for a multithreaded model download queue."""
import threading
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import total_ordering from functools import total_ordering

View File

@@ -196,6 +196,8 @@ class ModelDownloadQueue(DownloadQueue):
def subdownload_event(subjob: DownloadJobBase): def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource) assert isinstance(subjob, DownloadJobRemoteSource)
assert isinstance(job, DownloadJobRemoteSource) assert isinstance(job, DownloadJobRemoteSource)
if job.status != DownloadJobStatus.RUNNING: # do not update if we are cancelled or paused
return
if subjob.status == DownloadJobStatus.RUNNING: if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values()) job.bytes = sum(bytes_downloaded.values())
@@ -214,13 +216,15 @@ class ModelDownloadQueue(DownloadQueue):
self._update_job_status(job, DownloadJobStatus.RUNNING) self._update_job_status(job, DownloadJobStatus.RUNNING)
return return
subqueue = self.__class__( assert isinstance(job, DownloadJobRepoID)
self._update_job_status(job, DownloadJobStatus.RUNNING)
self._lock.acquire() # prevent status from being updated while we are setting up subqueue
try:
job.subqueue = self.__class__(
event_handlers=[subdownload_event], event_handlers=[subdownload_event],
requests_session=self._requests, requests_session=self._requests,
quiet=True, quiet=True,
) )
assert isinstance(job, DownloadJobRepoID)
try:
repo_id = job.source repo_id = job.source
variant = job.variant variant = job.variant
if not job.metadata: if not job.metadata:
@@ -235,7 +239,7 @@ class ModelDownloadQueue(DownloadQueue):
for url, subdir, file, size in urls_to_download: for url, subdir, file, size in urls_to_download:
job.total_bytes += size job.total_bytes += size
subqueue.create_download_job( job.subqueue.create_download_job(
source=url, source=url,
destdir=job.destination / subdir, destdir=job.destination / subdir,
filename=file, filename=file,
@@ -249,11 +253,11 @@ class ModelDownloadQueue(DownloadQueue):
self._update_job_status(job, DownloadJobStatus.ERROR) self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error) self._logger.error(job.error)
finally: finally:
job.subqueue = subqueue self._lock.release()
if job.subqueue is not None:
job.subqueue.join() job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING: if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED) self._update_job_status(job, DownloadJobStatus.COMPLETED)
job.subqueue.release() # get rid of the subqueue
def _get_repo_info( def _get_repo_info(
self, self,

View File

@@ -176,22 +176,6 @@ class DownloadQueue(DownloadQueueBase):
except KeyError as excp: except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def id_to_job(self, id: int) -> DownloadJobBase: def id_to_job(self, id: int) -> DownloadJobBase:
"""Translate a job ID into a DownloadJobBase object.""" """Translate a job ID into a DownloadJobBase object."""
try: try:
@@ -224,6 +208,22 @@ class DownloadQueue(DownloadQueueBase):
except (AssertionError, KeyError) as excp: except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
"""
Cancel the indicated job.
If it is running it will be stopped.
job.status will be set to DownloadJobStatus.CANCELLED
"""
with self._lock:
try:
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.preserve_partial_downloads = preserve_partial
self._update_job_status(job, DownloadJobStatus.CANCELLED)
job.cleanup()
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def start_all_jobs(self): def start_all_jobs(self):
"""Start (enqueue) all jobs that are idle or paused.""" """Start (enqueue) all jobs that are idle or paused."""
with self._lock: with self._lock:
@@ -273,9 +273,7 @@ class DownloadQueue(DownloadQueueBase):
if job == STOP_JOB: # marker that queue is done if job == STOP_JOB: # marker that queue is done
done = True done = True
if ( if job.status == DownloadJobStatus.ENQUEUED:
job.status == DownloadJobStatus.ENQUEUED
): # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet: if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}") self._logger.info(f"{job.source}: Downloading to {job.destination}")
do_download = self.select_downloader(job) do_download = self.select_downloader(job)
@@ -393,17 +391,18 @@ class DownloadQueue(DownloadQueueBase):
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None): def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
"""Optionally change the job status and send an event indicating a change of state.""" """Optionally change the job status and send an event indicating a change of state."""
with self._lock:
if new_status: if new_status:
job.status = new_status job.status = new_status
self._logger.debug(f"Status update for download job {job.id}: {job}")
if self._in_terminal_state(job) and not self._quiet: if self._in_terminal_state(job) and not self._quiet:
self._logger.info(f"{job.source}: download job completed with status {job.status.value}") self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
if new_status == DownloadJobStatus.RUNNING and not job.job_started: if new_status == DownloadJobStatus.RUNNING and not job.job_started:
job.job_started = time.time() job.job_started = time.time()
elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]: elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
job.job_ended = time.time() job.job_ended = time.time()
if job.event_handlers: if job.event_handlers:
for handler in job.event_handlers: for handler in job.event_handlers:
try: try:
@@ -428,9 +427,9 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.ERROR) self._update_job_status(job, DownloadJobStatus.ERROR)
def _cleanup_cancelled_job(self, job: DownloadJobBase): def _cleanup_cancelled_job(self, job: DownloadJobBase):
if job.preserve_partial_downloads: job.cleanup(job.preserve_partial_downloads)
return if not job.preserve_partial_downloads:
self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}") self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
dest = Path(job.destination) dest = Path(job.destination)
if dest.is_file(): if dest.is_file():
dest.unlink() dest.unlink()

View File

@@ -114,9 +114,9 @@ hf_sd2_paths = [
] ]
for path in hf_sd2_paths: for path in hf_sd2_paths:
url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}" url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
path = Path(path) path = Path(path).as_posix()
filename = path.name filename = Path(path).name
content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8") content = b"This is the content for path " + bytearray(path, "utf-8")
session.mount( session.mount(
url, url,
TestAdapter( TestAdapter(
@@ -314,9 +314,10 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
queue.prune_jobs() queue.prune_jobs()
assert len(queue.list_jobs()) == 0 assert len(queue.list_jobs()) == 0
def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase): def event_handler(job: DownloadJobBase):
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step time.sleep(0.1) # slow down the thread by blocking it just a bit at every step
if not INTERNET_AVAILABLE: if not INTERNET_AVAILABLE:
return return
@@ -338,7 +339,7 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
assert job1.status == "paused" assert job1.status == "paused"
queue.start_job(job1) queue.start_job(job1)
time.sleep(0.1) time.sleep(0.5)
assert job1.status == "running" assert job1.status == "running"
# check cancel # check cancel
@@ -346,8 +347,6 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
time.sleep(0.1) time.sleep(0.1)
assert job2.status == "running" assert job2.status == "running"
queue.cancel_job(job2) queue.cancel_job(job2)
time.sleep(0.1)
assert job2.status == "cancelled"
queue.join() queue.join()
assert job1.status == "completed" assert job1.status == "completed"
@@ -358,4 +357,6 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
tmpdir2, "stable-diffusion-2-1", "model_index.json" tmpdir2, "stable-diffusion-2-1", "model_index.json"
).exists(), "cancelled file should be deleted" ).exists(), "cancelled file should be deleted"
assert len(queue.list_jobs()) == 2
queue.prune_jobs()
assert len(queue.list_jobs()) == 0 assert len(queue.list_jobs()) == 0