Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
clean up model downloader status locking to avoid race conditions
@@ -40,6 +40,7 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from .cli.completer import set_autocompleter
     from .invocations.baseinvocation import BaseInvocation
     from .services.default_graphs import create_system_graphs, default_text_to_image_graph_id
+    from .services.download_manager import DownloadQueueService
     from .services.events import EventServiceBase
     from .services.graph import (
         Edge,
@@ -56,7 +57,6 @@ if True:  # hack to make flake8 happy with imports coming after setting up the c
     from .services.latent_storage import DiskLatentsStorage, ForwardCacheLatentsStorage
     from .services.model_install_service import ModelInstallService
     from .services.model_loader_service import ModelLoadService
-    from .services.download_manager import DownloadQueueService
     from .services.model_record_service import ModelRecordServiceBase
     from .services.processor import DefaultInvocationProcessor
     from .services.sqlite import SqliteItemStorage

@@ -1,6 +1,7 @@
 # Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
 """Abstract base class for a multithreaded model download queue."""
 
+import threading
 from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering

@@ -196,6 +196,8 @@ class ModelDownloadQueue(DownloadQueue):
         def subdownload_event(subjob: DownloadJobBase):
             assert isinstance(subjob, DownloadJobRemoteSource)
             assert isinstance(job, DownloadJobRemoteSource)
+            if job.status != DownloadJobStatus.RUNNING:  # do not update if we are cancelled or paused
+                return
             if subjob.status == DownloadJobStatus.RUNNING:
                 bytes_downloaded[subjob.id] = subjob.bytes
                 job.bytes = sum(bytes_downloaded.values())
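
Note on the two added lines: once the parent repo-id job leaves the RUNNING state (cancelled or paused), progress callbacks from its sub-downloads no longer touch it. A minimal standalone sketch of the same idea, using invented ParentJob/Status names rather than the project's DownloadJob classes:

    import threading
    from enum import Enum


    class Status(Enum):
        RUNNING = "running"
        CANCELLED = "cancelled"


    class ParentJob:
        """Aggregates byte counts reported by several child downloads."""

        def __init__(self) -> None:
            self.status = Status.RUNNING
            self.bytes = 0
            self._child_bytes: dict[int, int] = {}
            self._lock = threading.Lock()

        def on_child_progress(self, child_id: int, child_bytes: int) -> None:
            with self._lock:
                if self.status is not Status.RUNNING:
                    return  # ignore late callbacks after a cancel or pause
                self._child_bytes[child_id] = child_bytes
                self.bytes = sum(self._child_bytes.values())


    parent = ParentJob()
    parent.on_child_progress(1, 100)
    parent.status = Status.CANCELLED
    parent.on_child_progress(1, 200)  # ignored: parent is no longer running
    print(parent.bytes)               # 100
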
@@ -214,13 +216,15 @@ class ModelDownloadQueue(DownloadQueue):
                 self._update_job_status(job, DownloadJobStatus.RUNNING)
                 return
 
-        subqueue = self.__class__(
-            event_handlers=[subdownload_event],
-            requests_session=self._requests,
-            quiet=True,
-        )
         assert isinstance(job, DownloadJobRepoID)
+        self._update_job_status(job, DownloadJobStatus.RUNNING)
+        self._lock.acquire()  # prevent status from being updated while we are setting up subqueue
         try:
+            job.subqueue = self.__class__(
+                event_handlers=[subdownload_event],
+                requests_session=self._requests,
+                quiet=True,
+            )
             repo_id = job.source
             variant = job.variant
             if not job.metadata:
@@ -235,7 +239,7 @@ class ModelDownloadQueue(DownloadQueue):
 
             for url, subdir, file, size in urls_to_download:
                 job.total_bytes += size
-                subqueue.create_download_job(
+                job.subqueue.create_download_job(
                     source=url,
                     destdir=job.destination / subdir,
                     filename=file,
@@ -249,11 +253,11 @@ class ModelDownloadQueue(DownloadQueue):
             self._update_job_status(job, DownloadJobStatus.ERROR)
             self._logger.error(job.error)
         finally:
-            job.subqueue = subqueue
-        job.subqueue.join()
+            self._lock.release()
+        if job.subqueue is not None:
+            job.subqueue.join()
         if job.status == DownloadJobStatus.RUNNING:
             self._update_job_status(job, DownloadJobStatus.COMPLETED)
-        job.subqueue.release()  # get rid of the subqueue
 
     def _get_repo_info(
         self,
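
In the three hunks above, the repo-id downloader now marks the job RUNNING, acquires the queue lock while it attaches job.subqueue, releases the lock in the finally block, and only then joins the sub-queue. Waiting outside the lock matters because the worker threads need that same lock to post status updates. A toy illustration of that acquire/try/finally-then-join shape (MiniQueue and its names are made up for this sketch, not the InvokeAI API):

    import queue
    import threading


    class MiniQueue:
        """Toy stand-in for a download queue; not the InvokeAI classes."""

        def __init__(self) -> None:
            self._lock = threading.Lock()
            self._tasks: "queue.Queue[int]" = queue.Queue()
            self.results: list[int] = []
            threading.Thread(target=self._worker, daemon=True).start()

        def _worker(self) -> None:
            while True:
                item = self._tasks.get()
                with self._lock:  # workers also need the lock ...
                    self.results.append(item)
                self._tasks.task_done()

        def submit_batch(self, items: list[int]) -> None:
            self._lock.acquire()  # guard only the setup of shared state
            try:
                for item in items:
                    self._tasks.put(item)
            finally:
                self._lock.release()  # release BEFORE waiting, or the worker deadlocks
            self._tasks.join()  # wait for completion outside the lock


    q = MiniQueue()
    q.submit_batch([1, 2, 3])
    print(sorted(q.results))  # [1, 2, 3]
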
@@ -176,22 +176,6 @@ class DownloadQueue(DownloadQueueBase):
         except KeyError as excp:
             raise UnknownJobIDException("Unrecognized job") from excp
 
-    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
-        """
-        Cancel the indicated job.
-
-        If it is running it will be stopped.
-        job.status will be set to DownloadJobStatus.CANCELLED
-        """
-        with self._lock:
-            try:
-                assert isinstance(self._jobs[job.id], DownloadJobBase)
-                job.preserve_partial_downloads = preserve_partial
-                self._update_job_status(job, DownloadJobStatus.CANCELLED)
-                job.cleanup()
-            except (AssertionError, KeyError) as excp:
-                raise UnknownJobIDException("Unrecognized job") from excp
-
     def id_to_job(self, id: int) -> DownloadJobBase:
         """Translate a job ID into a DownloadJobBase object."""
         try:
@@ -224,6 +208,22 @@ class DownloadQueue(DownloadQueueBase):
             except (AssertionError, KeyError) as excp:
                 raise UnknownJobIDException("Unrecognized job") from excp
 
+    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
+        """
+        Cancel the indicated job.
+
+        If it is running it will be stopped.
+        job.status will be set to DownloadJobStatus.CANCELLED
+        """
+        with self._lock:
+            try:
+                assert isinstance(self._jobs[job.id], DownloadJobBase)
+                job.preserve_partial_downloads = preserve_partial
+                self._update_job_status(job, DownloadJobStatus.CANCELLED)
+                job.cleanup()
+            except (AssertionError, KeyError) as excp:
+                raise UnknownJobIDException("Unrecognized job") from excp
+
     def start_all_jobs(self):
         """Start (enqueue) all jobs that are idle or paused."""
         with self._lock:
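
cancel_job itself is unchanged by the move: it already mutated the job under `with self._lock:`. Worth noting is that it calls self._update_job_status() while still holding the lock, and after this commit _update_job_status takes the same lock (see the hunk below), so the queue's lock presumably has to be reentrant (threading.RLock); a plain threading.Lock would deadlock on the nested acquire. A standalone illustration of that difference, independent of the InvokeAI classes:

    import threading

    lock = threading.RLock()


    def outer_update() -> str:
        # the outer operation (think cancel_job) takes the lock ...
        with lock:
            return inner_update()


    def inner_update() -> str:
        # ... and a helper it calls (think _update_job_status) takes it again.
        # This works only because an RLock can be re-acquired by its owning thread.
        with lock:
            return "updated"


    print(outer_update())  # "updated"
    # With lock = threading.Lock(), the inner `with lock:` would block forever.
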
@@ -273,9 +273,7 @@ class DownloadQueue(DownloadQueueBase):
             if job == STOP_JOB:  # marker that queue is done
                 done = True
 
-            if (
-                job.status == DownloadJobStatus.ENQUEUED
-            ):  # Don't do anything for non-enqueued jobs (shouldn't happen)
+            if job.status == DownloadJobStatus.ENQUEUED:
                 if not self._quiet:
                     self._logger.info(f"{job.source}: Downloading to {job.destination}")
                 do_download = self.select_downloader(job)
@@ -393,17 +391,18 @@ class DownloadQueue(DownloadQueueBase):
 
     def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
         """Optionally change the job status and send an event indicating a change of state."""
-        if new_status:
-            job.status = new_status
-        self._logger.debug(f"Status update for download job {job.id}: {job}")
-        if self._in_terminal_state(job) and not self._quiet:
-            self._logger.info(f"{job.source}: download job completed with status {job.status.value}")
+        with self._lock:
+            if new_status:
+                job.status = new_status
 
-        if new_status == DownloadJobStatus.RUNNING and not job.job_started:
-            job.job_started = time.time()
-        elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
-            job.job_ended = time.time()
+            if self._in_terminal_state(job) and not self._quiet:
+                self._logger.info(f"{job.source}: Download job completed with status {job.status.value}")
+
+            if new_status == DownloadJobStatus.RUNNING and not job.job_started:
+                job.job_started = time.time()
+            elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
+                job.job_ended = time.time()
 
         if job.event_handlers:
             for handler in job.event_handlers:
                 try:
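
This is the core of the race-condition fix: the status change, the start/end timestamps, and the terminal-state logging now happen atomically under the queue lock, while the event handlers (the unchanged context lines) are still invoked after the lock is dropped, which keeps a handler that calls back into the queue from blocking on the lock it already triggered. A simplified sketch of that shape (StatusTracker and its string statuses are illustrative, not the project's types):

    import threading
    import time
    from typing import Callable, Optional


    class Job:
        def __init__(self) -> None:
            self.status = "idle"
            self.job_started: Optional[float] = None
            self.job_ended: Optional[float] = None
            self.event_handlers: list[Callable[["Job"], None]] = []


    class StatusTracker:
        """Toy version of a locked status update; not the InvokeAI implementation."""

        TERMINAL = {"completed", "error", "cancelled"}

        def __init__(self) -> None:
            self._lock = threading.Lock()

        def update(self, job: Job, new_status: Optional[str] = None) -> None:
            with self._lock:  # state changes are atomic ...
                if new_status:
                    job.status = new_status
                if new_status == "running" and not job.job_started:
                    job.job_started = time.time()
                elif new_status in self.TERMINAL:
                    job.job_ended = time.time()
            for handler in job.event_handlers:  # ... but callbacks run unlocked
                handler(job)


    job = Job()
    job.event_handlers.append(lambda j: print("status is now", j.status))
    tracker = StatusTracker()
    tracker.update(job, "running")
    tracker.update(job, "completed")
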
@@ -428,11 +427,11 @@ class DownloadQueue(DownloadQueueBase):
             self._update_job_status(job, DownloadJobStatus.ERROR)
 
     def _cleanup_cancelled_job(self, job: DownloadJobBase):
-        if job.preserve_partial_downloads:
-            return
-        self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
-        dest = Path(job.destination)
-        if dest.is_file():
-            dest.unlink()
-        elif dest.is_dir():
-            shutil.rmtree(dest.as_posix(), ignore_errors=True)
+        job.cleanup(job.preserve_partial_downloads)
+        if not job.preserve_partial_downloads:
+            self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
+            dest = Path(job.destination)
+            if dest.is_file():
+                dest.unlink()
+            elif dest.is_dir():
+                shutil.rmtree(dest.as_posix(), ignore_errors=True)

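
Besides the cleanup reordering, the hunk above fixes a small logging bug: the warning string was missing its f prefix, so the braces were logged literally instead of interpolating job.destination. A two-line illustration:

    destination = "/tmp/model"
    print("Cleaning up leftover files from cancelled download job {destination}")   # prints the braces literally
    print(f"Cleaning up leftover files from cancelled download job {destination}")  # interpolates the path
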
@@ -114,9 +114,9 @@ hf_sd2_paths = [
 ]
 for path in hf_sd2_paths:
     url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
-    path = Path(path)
-    filename = path.name
-    content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8")
+    path = Path(path).as_posix()
+    filename = Path(path).name
+    content = b"This is the content for path " + bytearray(path, "utf-8")
     session.mount(
         url,
         TestAdapter(
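
The fixture change keeps the mocked Hugging Face paths in POSIX form so they match the registered adapters on every platform: Path(...).as_posix() always yields forward slashes, whereas str(Path(...)) follows the host OS separator. A small pathlib-only illustration (the example path is made up; the real entries live in hf_sd2_paths):

    from pathlib import PurePosixPath, PureWindowsPath

    p = "text_encoder/model.fp16.safetensors"
    print(PurePosixPath(p).as_posix())    # text_encoder/model.fp16.safetensors
    print(str(PureWindowsPath(p)))        # text_encoder\model.fp16.safetensors
    print(PureWindowsPath(p).as_posix())  # text_encoder/model.fp16.safetensors
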
@@ -314,48 +314,49 @@ def test_pause_cancel_url():  # this one is tricky because of potential race conditions
     queue.prune_jobs()
     assert len(queue.list_jobs()) == 0
 
 
-def test_pause_cancel_repo_id():  # this one is tricky because of potential race conditions
-    def event_handler(job: DownloadJobBase):
-        time.sleep(0.5)  # slow down the thread by blocking it just a bit at every step
-
-    if not INTERNET_AVAILABLE:
-        return
-
-    repo_id = "stabilityai/stable-diffusion-2-1"
-    queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
-
-    with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
-        job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
-        job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
-        assert job1.status == "idle"
-        queue.start_job(job1)
-        time.sleep(0.1)  # wait for enqueueing
-        assert job1.status in ["enqueued", "running"]
-
-        # check pause and restart
-        queue.pause_job(job1)
-        time.sleep(0.1)  # wait to be paused
-        assert job1.status == "paused"
-
-        queue.start_job(job1)
-        time.sleep(0.1)
-        assert job1.status == "running"
-
-        # check cancel
-        queue.start_job(job2)
-        time.sleep(0.1)
-        assert job2.status == "running"
-        queue.cancel_job(job2)
-        time.sleep(0.1)
-        assert job2.status == "cancelled"
-
-        queue.join()
-        assert job1.status == "completed"
-        assert job2.status == "cancelled"
-
-        assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
-        assert not Path(
-            tmpdir2, "stable-diffusion-2-1", "model_index.json"
-        ).exists(), "cancelled file should be deleted"
-
-    assert len(queue.list_jobs()) == 0
+def test_pause_cancel_repo_id():  # this one is tricky because of potential race conditions
+    def event_handler(job: DownloadJobBase):
+        time.sleep(0.1)  # slow down the thread by blocking it just a bit at every step
+
+    if not INTERNET_AVAILABLE:
+        return
+
+    repo_id = "stabilityai/stable-diffusion-2-1"
+    queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
+
+    with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
+        job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
+        job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
+        assert job1.status == "idle"
+        queue.start_job(job1)
+        time.sleep(0.1)  # wait for enqueueing
+        assert job1.status in ["enqueued", "running"]
+
+        # check pause and restart
+        queue.pause_job(job1)
+        time.sleep(0.1)  # wait to be paused
+        assert job1.status == "paused"
+
+        queue.start_job(job1)
+        time.sleep(0.5)
+        assert job1.status == "running"
+
+        # check cancel
+        queue.start_job(job2)
+        time.sleep(0.1)
+        assert job2.status == "running"
+        queue.cancel_job(job2)
+
+        queue.join()
+        assert job1.status == "completed"
+        assert job2.status == "cancelled"
+
+        assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
+        assert not Path(
+            tmpdir2, "stable-diffusion-2-1", "model_index.json"
+        ).exists(), "cancelled file should be deleted"
+
+    assert len(queue.list_jobs()) == 2
+    queue.prune_jobs()
+    assert len(queue.list_jobs()) == 0
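
As its comment says, this test is race-prone: it relies on fixed sleeps for the queue to reach each state, and the rewrite mostly retunes those sleeps and defers the post-cancel assertion until after queue.join(). One common way to make such assertions less timing-sensitive is a small polling helper; wait_for below is a hypothetical utility sketched for illustration, not part of the test suite:

    import time
    from typing import Callable


    def wait_for(predicate: Callable[[], bool], timeout: float = 5.0, interval: float = 0.05) -> bool:
        """Poll until predicate() returns True or the timeout expires."""
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(interval)
        return predicate()


    # Illustrative usage against the jobs created in the test above:
    # assert wait_for(lambda: job1.status == "paused")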