mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Clean up model downloader status locking to avoid race conditions
This commit is contained in:
@ -40,6 +40,7 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .cli.completer import set_autocompleter
|
||||
from .invocations.baseinvocation import BaseInvocation
|
||||
from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
|
||||
from .services.download_manager import DownloadQueueService
|
||||
from .services.events import EventServiceBase
|
||||
from .services.graph import (
|
||||
Edge,
|
||||
@ -56,7 +57,6 @@ if True: # hack to make flake8 happy with imports coming after setting up the c
|
||||
from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
|
||||
from .services.model_install_service import ModelInstallService
|
||||
from .services.model_loader_service import ModelLoadService
|
||||
from .services.download_manager import DownloadQueueService
|
||||
from .services.model_record_service import ModelRecordServiceBase
|
||||
from .services.processor import DefaultInvocationProcessor
|
||||
from .services.sqlite import SqliteItemStorage
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
||||
"""Abstract base class for a multithreaded model download queue."""
|
||||
|
||||
import threading
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
|
@ -196,6 +196,8 @@ class ModelDownloadQueue(DownloadQueue):
|
||||
def subdownload_event(subjob: DownloadJobBase):
|
||||
assert isinstance(subjob, DownloadJobRemoteSource)
|
||||
assert isinstance(job, DownloadJobRemoteSource)
|
||||
if job.status != DownloadJobStatus.RUNNING: # do not update if we are cancelled or paused
|
||||
return
|
||||
if subjob.status == DownloadJobStatus.RUNNING:
|
||||
bytes_downloaded[subjob.id] = subjob.bytes
|
||||
job.bytes = sum(bytes_downloaded.values())
|
||||
@ -214,13 +216,15 @@ class ModelDownloadQueue(DownloadQueue):
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
return
|
||||
|
||||
subqueue = self.__class__(
|
||||
assert isinstance(job, DownloadJobRepoID)
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
self._lock.acquire() # prevent status from being updated while we are setting up subqueue
|
||||
try:
|
||||
job.subqueue = self.__class__(
|
||||
event_handlers=[subdownload_event],
|
||||
requests_session=self._requests,
|
||||
quiet=True,
|
||||
)
|
||||
assert isinstance(job, DownloadJobRepoID)
|
||||
try:
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
if not job.metadata:
|
||||
@ -235,7 +239,7 @@ class ModelDownloadQueue(DownloadQueue):
|
||||
|
||||
for url, subdir, file, size in urls_to_download:
|
||||
job.total_bytes += size
|
||||
subqueue.create_download_job(
|
||||
job.subqueue.create_download_job(
|
||||
source=url,
|
||||
destdir=job.destination / subdir,
|
||||
filename=file,
|
||||
@ -249,11 +253,11 @@ class ModelDownloadQueue(DownloadQueue):
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
self._logger.error(job.error)
|
||||
finally:
|
||||
job.subqueue = subqueue
|
||||
self._lock.release()
|
||||
if job.subqueue is not None:
|
||||
job.subqueue.join()
|
||||
if job.status == DownloadJobStatus.RUNNING:
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
job.subqueue.release() # get rid of the subqueue
|
||||
|
||||
def _get_repo_info(
|
||||
self,
|
||||
|
@ -176,22 +176,6 @@ class DownloadQueue(DownloadQueueBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
    """
    Cancel the indicated job.

    If it is running it will be stopped.
    job.status will be set to DownloadJobStatus.CANCELLED

    :param job: The job to cancel; job.id must be registered in this queue.
    :param preserve_partial: Stored on job.preserve_partial_downloads before
        the status change is announced.
    :raises UnknownJobIDException: If job.id is not known to this queue.
    """
    with self._lock:
        try:
            # Validate with an explicit check instead of `assert`: assert
            # statements are removed under `python -O`, which would also drop
            # the `self._jobs[job.id]` lookup that detects unknown job ids.
            if not isinstance(self._jobs[job.id], DownloadJobBase):
                raise UnknownJobIDException("Unrecognized job")
            job.preserve_partial_downloads = preserve_partial
            self._update_job_status(job, DownloadJobStatus.CANCELLED)
            job.cleanup()
        except (AssertionError, KeyError) as excp:
            # NOTE(review): this handler also covers the status update and
            # job.cleanup(), so an AssertionError/KeyError raised there is
            # reported as an unknown job too. Kept as-is for compatibility.
            raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def id_to_job(self, id: int) -> DownloadJobBase:
|
||||
"""Translate a job ID into a DownloadJobBase object."""
|
||||
try:
|
||||
@ -224,6 +208,22 @@ class DownloadQueue(DownloadQueueBase):
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
    """
    Cancel the indicated job.

    If it is running it will be stopped.
    job.status will be set to DownloadJobStatus.CANCELLED

    :param job: The job to cancel; job.id must be registered in this queue.
    :param preserve_partial: Stored on job.preserve_partial_downloads before
        the status change is announced.
    :raises UnknownJobIDException: If job.id is not known to this queue.
    """
    with self._lock:
        try:
            # Validate with an explicit check instead of `assert`: assert
            # statements are removed under `python -O`, which would also drop
            # the `self._jobs[job.id]` lookup that detects unknown job ids.
            if not isinstance(self._jobs[job.id], DownloadJobBase):
                raise UnknownJobIDException("Unrecognized job")
            job.preserve_partial_downloads = preserve_partial
            self._update_job_status(job, DownloadJobStatus.CANCELLED)
            job.cleanup()
        except (AssertionError, KeyError) as excp:
            # NOTE(review): this handler also covers the status update and
            # job.cleanup(), so an AssertionError/KeyError raised there is
            # reported as an unknown job too. Kept as-is for compatibility.
            raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_all_jobs(self):
|
||||
"""Start (enqueue) all jobs that are idle or paused."""
|
||||
with self._lock:
|
||||
@ -273,9 +273,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
done = True
|
||||
|
||||
if (
|
||||
job.status == DownloadJobStatus.ENQUEUED
|
||||
): # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
if job.status == DownloadJobStatus.ENQUEUED:
|
||||
if not self._quiet:
|
||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||
do_download = self.select_downloader(job)
|
||||
@ -393,17 +391,18 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
|
||||
"""Optionally change the job status and send an event indicating a change of state."""
|
||||
with self._lock:
|
||||
if new_status:
|
||||
job.status = new_status
|
||||
|
||||
self._logger.debug(f"Status update for download job {job.id}: {job}")
|
||||
if self._in_terminal_state(job) and not self._quiet:
|
||||
self._logger.info(f"{job.source}: download job completed with status {job.status.value}")
|
||||
self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
|
||||
|
||||
if new_status == DownloadJobStatus.RUNNING and not job.job_started:
|
||||
job.job_started = time.time()
|
||||
elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
|
||||
job.job_ended = time.time()
|
||||
|
||||
if job.event_handlers:
|
||||
for handler in job.event_handlers:
|
||||
try:
|
||||
@ -428,9 +427,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJobBase):
|
||||
if job.preserve_partial_downloads:
|
||||
return
|
||||
self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
|
||||
job.cleanup(job.preserve_partial_downloads)
|
||||
if not job.preserve_partial_downloads:
|
||||
self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
|
||||
dest = Path(job.destination)
|
||||
if dest.is_file():
|
||||
dest.unlink()
|
||||
|
@ -114,9 +114,9 @@ hf_sd2_paths = [
|
||||
]
|
||||
for path in hf_sd2_paths:
|
||||
url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
|
||||
path = Path(path)
|
||||
filename = path.name
|
||||
content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8")
|
||||
path = Path(path).as_posix()
|
||||
filename = Path(path).name
|
||||
content = b"This is the content for path " + bytearray(path, "utf-8")
|
||||
session.mount(
|
||||
url,
|
||||
TestAdapter(
|
||||
@ -314,9 +314,10 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
queue.prune_jobs()
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
||||
def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions
|
||||
|
||||
def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions
|
||||
def event_handler(job: DownloadJobBase):
|
||||
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
|
||||
time.sleep(0.1) # slow down the thread by blocking it just a bit at every step
|
||||
|
||||
if not INTERNET_AVAILABLE:
|
||||
return
|
||||
@ -338,7 +339,7 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
assert job1.status == "paused"
|
||||
|
||||
queue.start_job(job1)
|
||||
time.sleep(0.1)
|
||||
time.sleep(0.5)
|
||||
assert job1.status == "running"
|
||||
|
||||
# check cancel
|
||||
@ -346,8 +347,6 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "running"
|
||||
queue.cancel_job(job2)
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "cancelled"
|
||||
|
||||
queue.join()
|
||||
assert job1.status == "completed"
|
||||
@ -358,4 +357,6 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
||||
tmpdir2, "stable-diffusion-2-1", "model_index.json"
|
||||
).exists(), "cancelled file should be deleted"
|
||||
|
||||
assert len(queue.list_jobs()) == 2
|
||||
queue.prune_jobs()
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
Reference in New Issue
Block a user