implement regression tests for pause/cancel/error conditions

Lincoln Stein
2023-09-07 17:06:59 -04:00
parent 79b2423159
commit a7aca29765
5 changed files with 332 additions and 116 deletions

View File

@@ -22,7 +22,7 @@ class DownloadQueueServiceBase(ABC):
start: bool = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
) -> DownloadJobBase:
"""
Create a download job.
@@ -73,26 +73,26 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def start_job(self, id: int):
def start_job(self, job: DownloadJobBase):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, id: int):
def pause_job(self, job: DownloadJobBase):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, id: int):
def cancel_job(self, job: DownloadJobBase):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def change_priority(self, id: int, delta: int):
def change_priority(self, job: DownloadJobBase, delta: int):
"""
Change the job's priority.
:param id: ID of the job
:param job: Job to apply change to
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.
@@ -132,7 +132,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
start: bool = True,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
) -> DownloadJobBase:
event_handlers = event_handlers or []
if self._event_bus:
event_handlers.append(self._event_bus.emit_model_download_event)
@@ -160,16 +160,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
def cancel_all_jobs(self):
return self._queue.cancel_all_jobs()
def start_job(self, id: int):
def start_job(self, job: DownloadJobBase):
return self._queue.start_job(job)
def pause_job(self, id: int):
def pause_job(self, job: DownloadJobBase):
return self._queue.pause_job(job)
def cancel_job(self, id: int):
def cancel_job(self, job: DownloadJobBase):
return self._queue.cancel_job(job)
def change_priority(self, id: int, delta: int):
def change_priority(self, job: DownloadJobBase, delta: int):
return self._queue.change_priority(job, delta)
def join(self):
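After this change the queue hands back the DownloadJobBase itself, and all control methods take that handle instead of a numeric ID. A sketch of the new calling convention (the service instance and paths here are illustrative, not names from this diff):

# `download_queue` is an illustrative DownloadQueueService instance.
job = download_queue.create_download_job(
    source="https://example.com/model.safetensors",
    destdir="/tmp/models",
    start=False,
)
download_queue.change_priority(job, -10)  # lower values run first; the default priority is 10
download_queue.start_job(job)             # puts the job into ENQUEUED state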

View File

@@ -5,7 +5,6 @@ from .base import ( # noqa F401
DownloadJobStatus,
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase,
)

View File

@@ -20,14 +20,11 @@ class DownloadJobStatus(str, Enum):
PAUSED = "paused" # previously started, now paused
COMPLETED = "completed" # finished running
ERROR = "error" # terminated with an error message
CANCELLED = "cancelled" # terminated by caller
class UnknownJobIDException(Exception):
"""Raised when an invalid Job ID is requested."""
class CancelledJobException(Exception):
"""Raised when a job is cancelled."""
"""Raised when an invalid Job is referenced."""
DownloadEventHandler = Callable[["DownloadJobBase"], None]
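An event handler is just a callable that receives the job each time its state changes. A minimal sketch of a conforming handler (the logger name is illustrative, not part of this module):

import logging

logger = logging.getLogger("download_events")  # illustrative name

def log_download_event(job: DownloadJobBase) -> None:
    # Invoked by the queue whenever the job's status or byte count changes.
    logger.info("job %s: status=%s bytes=%d", job.id, job.status, job.bytes)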
@@ -85,7 +82,7 @@ class DownloadQueueBase(ABC):
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
) -> DownloadJobBase:
"""
Create a download job.
@@ -101,7 +98,7 @@ class DownloadQueueBase(ABC):
pass
@abstractmethod
def release(self) -> int:
def release(self):
"""
Release resources used by queue.
@@ -127,7 +124,7 @@ class DownloadQueueBase(ABC):
:param id: ID of the DownloadJobBase.
Exceptions:
* UnknownJobIDException
* UnknownJobException
Note that once a job is completed, id_to_job() may no longer
recognize the job. Call id_to_job() before the job completes
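Since completed jobs can be dropped from the internal table, the safe pattern is to keep the job object returned by create_download_job rather than re-looking it up by ID afterwards. A sketch, assuming a queue configured as in the tests below:

job = queue.create_download_job(
    source="http://www.civitai.com/models/12345",  # illustrative URL from the tests below
    destdir="/tmp/models",
    start=True,
)
queue.join()
# The handle stays valid even though id_to_job(job.id) may now raise UnknownJobIDException.
assert job.status == DownloadJobStatus.COMPLETED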
@@ -152,26 +149,26 @@ class DownloadQueueBase(ABC):
pass
@abstractmethod
def start_job(self, id: int):
def start_job(self, job: DownloadJobBase):
"""Start the job putting it into ENQUEUED state."""
pass
@abstractmethod
def pause_job(self, id: int):
def pause_job(self, job: DownloadJobBase):
"""Pause the job, putting it into PAUSED state."""
pass
@abstractmethod
def cancel_job(self, id: int):
def cancel_job(self, job: DownloadJobBase):
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@abstractmethod
def change_priority(self, id: int, delta: int):
def change_priority(self, job: DownloadJobBase, delta: int):
"""
Change the job's priority.
:param id: ID of the job
:param job: Job to change
:param delta: Value to increment or decrement priority.
Lower values are higher priority. The default starting value is 10.

View File

@@ -4,10 +4,12 @@
import re
import os
import requests
import shutil
import threading
import time
from pathlib import Path
from requests import HTTPError
from typing import Dict, Optional, Set, List, Tuple
from pydantic import Field, validator, ValidationError
@@ -22,7 +24,6 @@ from .base import (
DownloadJobStatus,
DownloadEventHandler,
UnknownJobIDException,
CancelledJobException,
DownloadJobBase,
)
@@ -95,7 +96,7 @@ class DownloadQueue(DownloadQueueBase):
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> int:
) -> DownloadJobBase:
"""Create a download job and return its ID."""
if re.match(r"^[\w-]+/[\w-]+$", source):
cls = DownloadJobRepoID
@@ -119,8 +120,8 @@ class DownloadQueue(DownloadQueueBase):
finally:
self._lock.release()
if start:
self.start_job(id)
return job.id
self.start_job(job)
return job
def release(self):
"""Signal our threads to exit when queue done."""
@@ -136,31 +137,31 @@ class DownloadQueue(DownloadQueueBase):
"""List all the jobs."""
return self._jobs.values()
def change_priority(self, id: int, delta: int):
def change_priority(self, job: DownloadJobBase, delta: int):
"""Change the priority of a job. Smaller priorities run first."""
try:
self._lock.acquire()
job = self._jobs[id]
assert isinstance(self._jobs[job.id], DownloadJobBase)
job.priority += delta
except KeyError as excp:
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
def cancel_job(self, id: str):
def cancel_job(self, job: DownloadJobBase):
"""
Cancel the indicated job.
If it is running it will be stopped.
job.error will be set to CancelledJobException.
job.status will be set to DownloadJobStatus.CANCELLED
"""
try:
self._lock.acquire()
job = self._jobs[id]
job.status = DownloadJobStatus.ERROR
job.error = CancelledJobException(f"Job {job.id} cancelled at caller's request")
self._update_job_status
del self._jobs[job.id]
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.CANCELLED)
# del self._jobs[job.id]
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
@@ -171,16 +172,16 @@ class DownloadQueue(DownloadQueueBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def start_job(self, id: int):
def start_job(self, job: DownloadJobBase):
"""Enqueue (start) the indicated job."""
try:
job = self._jobs[id]
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
self._queue.put(job)
except KeyError as excp:
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def pause_job(self, id: int):
def pause_job(self, job: DownloadJobBase):
"""
Pause (dequeue) the indicated job.
@@ -189,9 +190,9 @@ class DownloadQueue(DownloadQueueBase):
"""
try:
self._lock.acquire()
job = self._jobs[id]
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.PAUSED)
except KeyError as excp:
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:
self._lock.release()
@@ -200,13 +201,13 @@ class DownloadQueue(DownloadQueueBase):
"""Start (enqueue) all jobs that are idle or paused."""
try:
self._lock.acquire()
for id, job in self._jobs.items():
for job in self._jobs.values():
if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
self.start_job(id)
self.start_job(job)
finally:
self._lock.release()
def pause_all_jobs(self, id: int):
def pause_all_jobs(self):
"""Pause all running jobs."""
try:
self._lock.acquire()
@@ -221,11 +222,18 @@ class DownloadQueue(DownloadQueueBase):
try:
self._lock.acquire()
for id, job in self._jobs.items():
if job.status in [DownloadJobStatus.RUNNING, DownloadJobStatus.PAUSED, DownloadJobStatus.ENQUEUED]:
if not self._in_terminal_state(job):
self.cancel_job(job)
finally:
self._lock.release()
def _in_terminal_state(self, job: DownloadJobBase):
return job.status in [
DownloadJobStatus.COMPLETED,
DownloadJobStatus.ERROR,
DownloadJobStatus.CANCELLED,
]
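With CANCELLED added alongside COMPLETED and ERROR, callers that wait on a single job can poll for any terminal state. A sketch of that pattern (wait_for_job is an illustrative helper, not part of the queue API):

import time

TERMINAL_STATES = {DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR, DownloadJobStatus.CANCELLED}

def wait_for_job(job: DownloadJobBase, poll_interval: float = 0.1) -> DownloadJobStatus:
    # Block until the job reaches a terminal state, then return that state.
    while job.status not in TERMINAL_STATES:
        time.sleep(poll_interval)
    return job.status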
def _start_workers(self, max_workers: int):
"""Start the requested number of worker threads."""
for i in range(0, max_workers):
@@ -247,13 +255,21 @@ class DownloadQueue(DownloadQueueBase):
if job == STOP_JOB: # marker that queue is done
break
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
if isinstance(job, DownloadJobURL):
self._download_with_resume(job)
elif isinstance(job, DownloadJobRepoID):
self._download_repoid(job)
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}")
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
if self._in_terminal_state(job):
del self._jobs[job.id]
self._queue.task_done()
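The worker threads use the standard sentinel-shutdown idiom: each one blocks on the queue and exits when it dequeues the STOP_JOB marker. A self-contained sketch of the idiom (generic names; the real code uses a PriorityQueue of job objects):

import queue
import threading

work_queue = queue.Queue()  # the real implementation uses a PriorityQueue of jobs
STOP = object()             # sentinel; the real code enqueues a dedicated STOP_JOB instance

def worker():
    while True:
        item = work_queue.get()
        try:
            if item is STOP:  # marker that the queue is done
                break
            # ... service the item here ...
        finally:
            work_queue.task_done()

threads = [threading.Thread(target=worker, daemon=True) for _ in range(4)]
for t in threads:
    t.start()
# shutdown: one sentinel per worker, then wait for the queue to drain
for _ in threads:
    work_queue.put(STOP)
work_queue.join()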
def _download_with_resume(self, job: DownloadJobBase):
@@ -277,33 +293,34 @@ class DownloadQueue(DownloadQueueBase):
dest = job.destination
dest.parent.mkdir(parents=True, exist_ok=True)
if dest.exists():
job.bytes = dest.stat().st_size
header["Range"] = f"bytes={job.bytes}-"
open_mode = "ab"
resp = self._requests.get(job.source, headers=header, stream=True) # new request with range
if exist_size > content_length:
self._logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
self._logger.warning(f"{dest}: complete file found. Skipping.")
self._update_job_status(job, DownloadJobStatus.COMPLETED)
return
if resp.status_code == 206 or exist_size > 0:
self._logger.warning(f"{dest}: partial file found. Resuming...")
elif resp.status_code != 200:
self._logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
else:
self._logger.info(f"{dest}: Downloading...")
self._update_job_status(job, DownloadJobStatus.RUNNING)
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
try:
if dest.exists():
job.bytes = dest.stat().st_size
header["Range"] = f"bytes={job.bytes}-"
open_mode = "ab"
resp = self._requests.get(job.source, headers=header, stream=True) # new request with range
if exist_size > content_length:
self._logger.warning("corrupt existing file found. re-downloading")
os.remove(dest)
exist_size = 0
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
self._logger.warning(f"{dest}: complete file found. Skipping.")
self._update_job_status(job, DownloadJobStatus.COMPLETED)
return
if resp.status_code == 206 or exist_size > 0:
self._logger.warning(f"{dest}: partial file found. Resuming...")
elif resp.status_code != 200:
raise HTTPError(resp.reason)
else:
self._logger.info(f"{dest}: Downloading...")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
self._update_job_status(job, DownloadJobStatus.RUNNING)
with open(dest, open_mode) as file:
for data in resp.iter_content(chunk_size=16384):
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
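The download path relies on the standard HTTP Range mechanism: when a partial file exists, request bytes=<offset>- and append on a 206 response; a 416 response means the file is already complete. A stripped-down sketch of the technique (stand-alone, with illustrative names; the real method also emits job status updates and restarts on a 200-while-resuming response):

import requests
from pathlib import Path
from typing import Optional

def resume_download(url: str, dest: Path, session: Optional[requests.Session] = None):
    # Illustrative stand-alone version of the Range-header resume technique.
    session = session or requests.Session()
    headers = {}
    open_mode = "wb"
    if dest.exists():
        headers["Range"] = f"bytes={dest.stat().st_size}-"
        open_mode = "ab"
    resp = session.get(url, headers=headers, stream=True)
    if resp.status_code == 416:  # requested range starts at or past EOF: already complete
        return
    resp.raise_for_status()      # 200 (fresh download) or 206 (resumed) fall through
    with open(dest, open_mode) as file:
        for chunk in resp.iter_content(chunk_size=16384):
            file.write(chunk)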
@@ -314,7 +331,6 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job)
self._update_job_status(job, DownloadJobStatus.COMPLETED)
del self._jobs[job.id]
except Exception as excp:
self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}")
job.error = excp
@@ -355,7 +371,10 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
subqueue = self.__class__(event_handlers=[subdownload_event])
subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
)
try:
repo_id = job.source
variant = job.variant
@@ -431,3 +450,11 @@ class DownloadQueue(DownloadQueueBase):
for v in basenames.values():
result.add(v)
return result
def _cleanup_cancelled_job(self, job: DownloadJobBase):
self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
dest = Path(job.destination)
if dest.is_file():
dest.unlink()
elif dest.is_dir():
shutil.rmtree(dest.as_posix(), ignore_errors=True)

View File

@@ -1,40 +1,100 @@
"""Test the queued download facility"""
import tempfile
import time
import requests
from requests_testadapter import TestAdapter
from requests import HTTPError
from pathlib import Path
from invokeai.backend.model_manager.download import (
DownloadJobStatus,
DownloadQueue,
DownloadJobBase,
UnknownJobIDException,
)
TestAdapter.__test__ = False
SAFETENSORS1_CONTENT = b"I am a safetensors file (1)"
SAFETENSORS1_HEADER = {
"Content-Length": len(SAFETENSORS1_CONTENT),
"Content-Disposition": 'filename="mock1.safetensors"',
}
SAFETENSORS2_CONTENT = b"I am a safetensors file (2)"
SAFETENSORS2_HEADER = {
"Content-Length": len(SAFETENSORS2_CONTENT),
"Content-Disposition": 'filename="mock2.safetensors"',
}
INTERNET_AVAILABLE = requests.get("http://www.google.com/").status_code == 200
########################################################################################
# Lots of dummy content here to test model download without using lots of bandwidth
# The repo_id tests are not self-contained because they still need to use the HF API
# to retrieve metadata about the files to download. However, the big weights files
# are not downloaded.
# If the internet is not available, then the repo_id tests are skipped, but the single
# URL tests are still run.
session = requests.Session()
session.mount(
"http://www.civitai.com/models/12345", TestAdapter(SAFETENSORS1_CONTENT, status=200, headers=SAFETENSORS1_HEADER)
)
session.mount(
"http://www.civitai.com/models/9999", TestAdapter(SAFETENSORS2_CONTENT, status=200, headers=SAFETENSORS2_HEADER)
)
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
session.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
# mock HuggingFace URLs
hf_sd2_paths = [
"feature_extractor/preprocessor_config.json",
"scheduler/scheduler_config.json",
"text_encoder/config.json",
"text_encoder/model.fp16.safetensors",
"text_encoder/model.safetensors",
"text_encoder/pytorch_model.fp16.bin",
"text_encoder/pytorch_model.bin",
"tokenizer/merges.txt",
"tokenizer/special_tokens_map.json",
"tokenizer/tokenizer_config.json",
"tokenizer/vocab.json",
"unet/config.json",
"unet/diffusion_pytorch_model.fp16.safetensors",
"unet/diffusion_pytorch_model.safetensors",
"vae/config.json",
"vae/diffusion_pytorch_model.fp16.safetensors",
"vae/diffusion_pytorch_model.safetensors",
]
for path in hf_sd2_paths:
url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
path = Path(path)
filename = path.name
content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8")
session.mount(
url,
TestAdapter(
content,
status=200,
headers={"Content-Length": len(content), "Content-Disposition": f'filename="{filename}"'},
),
)
def test_basic_queue():
# This is the content of `model_index.json` for stable-diffusion-2-1
model_index_content = b'{"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.8.0", "feature_extractor": ["transformers", "CLIPImageProcessor"], "requires_safety_checker": false, "safety_checker": [null, null], "scheduler": ["diffusers", "DDIMScheduler"], "text_encoder": ["transformers", "CLIPTextModel"], "tokenizer": ["transformers", "CLIPTokenizer"], "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]}'
session.mount(
"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/model_index.json",
TestAdapter(
model_index_content,
status=200,
headers={"Content-Length": len(model_index_content), "Content-Disposition": f'filename="model_index.json"'},
),
)
########################################################################################
def test_basic_queue_download():
events = list()
def event_handler(job: DownloadJobBase):
@@ -45,20 +105,18 @@ def test_basic_queue():
event_handlers=[event_handler],
)
with tempfile.TemporaryDirectory() as tmpdir:
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
assert isinstance(id1, int), "expected first job id to be numeric"
job = queue.id_to_job(id1)
assert isinstance(job, DownloadJobBase), "expected job to be a DownloadJobBase"
job = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
assert isinstance(job, DownloadJobBase), "expected the job to be of type DownloadJobBase"
assert isinstance(job.id, int), "expected the job id to be numeric"
assert job.status == "idle", "expected job status to be idle"
assert job.status == DownloadJobStatus.IDLE
queue.start_job(id1)
queue.start_job(job)
queue.join()
assert events[0] == DownloadJobStatus.ENQUEUED
assert events[-1] == DownloadJobStatus.COMPLETED
assert DownloadJobStatus.RUNNING in events
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
def test_queue_priority():
@@ -67,31 +125,166 @@ def test_queue_priority():
)
with tempfile.TemporaryDirectory() as tmpdir:
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
queue.change_priority(id1, -10) # make id1 run first
job1 = queue.id_to_job(id1)
job2 = queue.id_to_job(id2)
queue.change_priority(job1, -10) # make id1 run first
assert job1 < job2
queue.start_all_jobs()
queue.join()
assert job1.job_sequence < job2.job_sequence
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
queue.change_priority(id2, -10) # make id2 run first
job1 = queue.id_to_job(id1)
job2 = queue.id_to_job(id2)
queue.change_priority(job2, -10) # make id2 run first
assert job2 < job1
queue.start_all_jobs()
queue.join()
assert job2.job_sequence < job1.job_sequence
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock2.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
assert Path(tmpdir, "mock9999.safetensors").exists(), f"expected {tmpdir}/mock9999.safetensors to exist"
def test_repo_id_download():
if not INTERNET_AVAILABLE:
return
repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue(
requests_session=session,
)
# first with fp16 variant
with tempfile.TemporaryDirectory() as tmpdir:
queue.create_download_job(source=repo_id, destdir=tmpdir, variant="fp16", start=True)
queue.join()
repo_root = Path(tmpdir, "stable-diffusion-2-1")
assert repo_root.exists()
assert Path(repo_root, "model_index.json").exists()
assert Path(repo_root, "text_encoder", "config.json").exists()
assert Path(repo_root, "text_encoder", "model.fp16.safetensors").exists()
# then without fp16
with tempfile.TemporaryDirectory() as tmpdir:
queue.create_download_job(source=repo_id, destdir=tmpdir, start=True)
queue.join()
repo_root = Path(tmpdir, "stable-diffusion-2-1")
assert Path(repo_root, "text_encoder", "model.safetensors").exists()
assert not Path(repo_root, "text_encoder", "model.fp16.safetensors").exists()
def test_failure_modes():
queue = DownloadQueue(
requests_session=session,
)
with tempfile.TemporaryDirectory() as tmpdir:
job = queue.create_download_job(source="http://www.civitai.com/models/broken", destdir=tmpdir)
queue.join()
assert job.status == "error"
assert isinstance(job.error, HTTPError)
assert str(job.error) == "NOT FOUND"
# create a foreign job which will be invalid for the queue
bad_job = DownloadJobBase(id=999, source="mock", destination="mock")
try:
queue.start_job(bad_job) # this should fail
succeeded = True
except UnknownJobIDException:
succeeded = False
assert not succeeded
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase):
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir:
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
job3 = queue.create_download_job(source="http://www.civitai.com/models/54321", destdir=tmpdir, start=False)
assert job1.status == "idle"
queue.start_job(job1)
queue.start_job(job3)
time.sleep(0.1) # wait for enqueueing
assert job1.status in ["enqueued", "running"]
# check pause and restart
queue.pause_job(job1)
time.sleep(0.1) # wait to be paused
assert job1.status == "paused"
queue.start_job(job1)
time.sleep(0.1)
assert job1.status == "running"
# check cancel
queue.start_job(job2)
time.sleep(0.1)
assert job2.status == "running"
queue.cancel_job(job2)
time.sleep(0.1)
assert job2.status == "cancelled"
queue.join()
assert job1.status == "completed"
assert job2.status == "cancelled"
assert job3.status == "completed"
assert Path(tmpdir, "mock12345.safetensors").exists()
assert Path(tmpdir, "mock9999.safetensors").exists() is False, "cancelled file should be deleted"
assert Path(tmpdir, "mock54321.safetensors").exists()
assert len(queue.list_jobs()) == 0
def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase):
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
if not INTERNET_AVAILABLE:
return
repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
assert job1.status == "idle"
queue.start_job(job1)
time.sleep(0.1) # wait for enqueueing
assert job1.status in ["enqueued", "running"]
# check pause and restart
queue.pause_job(job1)
time.sleep(0.1) # wait to be paused
assert job1.status == "paused"
queue.start_job(job1)
time.sleep(0.1)
assert job1.status == "running"
# check cancel
queue.start_job(job2)
time.sleep(0.1)
assert job2.status == "running"
queue.cancel_job(job2)
time.sleep(0.1)
assert job2.status == "cancelled"
queue.join()
assert job1.status == "completed"
assert job2.status == "cancelled"
assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
assert not Path(
tmpdir2, "stable-diffusion-2-1", "model_index.json"
).exists(), "cancelled file should be deleted"
assert len(queue.list_jobs()) == 0