mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
implement regression tests for pause/cancel/error conditions
This commit is contained in:
@ -22,7 +22,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
@ -73,26 +73,26 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, id: int):
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, id: int):
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, id: int):
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_priority(self, id: int, delta: int):
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
"""
|
||||
Change the job's priority.
|
||||
|
||||
:param id: ID of the job
|
||||
:param job: Job to apply change to
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
@ -132,7 +132,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
start: bool = True,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
) -> DownloadJobBase:
|
||||
event_handlers = event_handlers or []
|
||||
if self._event_bus:
|
||||
event_handlers.append([self._event_bus.emit_model_download_event])
|
||||
@ -160,16 +160,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def cancel_all_jobs(self):
|
||||
return self._queue.cancel_all_jobs()
|
||||
|
||||
def start_job(self, id: int):
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
return self._queue.start_job(id)
|
||||
|
||||
def pause_job(self, id: int):
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
return self._queue.pause_job(id)
|
||||
|
||||
def cancel_job(self, id: int):
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
return self._queue.cancel_job(id)
|
||||
|
||||
def change_priority(self, id: int, delta: int):
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
return self._queue.change_priority(id, delta)
|
||||
|
||||
def join(self):
|
||||
|
@ -5,7 +5,6 @@ from .base import ( # noqa F401
|
||||
DownloadJobStatus,
|
||||
DownloadEventHandler,
|
||||
UnknownJobIDException,
|
||||
CancelledJobException,
|
||||
DownloadJobBase,
|
||||
)
|
||||
|
||||
|
@ -20,14 +20,11 @@ class DownloadJobStatus(str, Enum):
|
||||
PAUSED = "paused" # previously started, now paused
|
||||
COMPLETED = "completed" # finished running
|
||||
ERROR = "error" # terminated with an error message
|
||||
CANCELLED = "cancelled" # terminated by caller
|
||||
|
||||
|
||||
class UnknownJobIDException(Exception):
|
||||
"""Raised when an invalid Job ID is requested."""
|
||||
|
||||
|
||||
class CancelledJobException(Exception):
|
||||
"""Raised when a job is cancelled."""
|
||||
"""Raised when an invalid Job is referenced."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
||||
@ -85,7 +82,7 @@ class DownloadQueueBase(ABC):
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
) -> DownloadJobBase:
|
||||
"""
|
||||
Create a download job.
|
||||
|
||||
@ -101,7 +98,7 @@ class DownloadQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def release(self) -> int:
|
||||
def release(self):
|
||||
"""
|
||||
Release resources used by queue.
|
||||
|
||||
@ -127,7 +124,7 @@ class DownloadQueueBase(ABC):
|
||||
:param id: ID of the DownloadJobBase.
|
||||
|
||||
Exceptions:
|
||||
* UnknownJobIDException
|
||||
* UnknownJobException
|
||||
|
||||
Note that once a job is completed, id_to_job() may no longer
|
||||
recognize the job. Call id_to_job() before the job completes
|
||||
@ -152,26 +149,26 @@ class DownloadQueueBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def start_job(self, id: int):
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Start the job putting it into ENQUEUED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def pause_job(self, id: int):
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""Pause the job, putting it into PAUSED state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, id: int):
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def change_priority(self, id: int, delta: int):
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
"""
|
||||
Change the job's priority.
|
||||
|
||||
:param id: ID of the job
|
||||
:param job: Job to change
|
||||
:param delta: Value to increment or decrement priority.
|
||||
|
||||
Lower values are higher priority. The default starting value is 10.
|
||||
|
@ -4,10 +4,12 @@
|
||||
import re
|
||||
import os
|
||||
import requests
|
||||
import shutil
|
||||
import threading
|
||||
import time
|
||||
|
||||
from pathlib import Path
|
||||
from requests import HTTPError
|
||||
from typing import Dict, Optional, Set, List, Tuple
|
||||
|
||||
from pydantic import Field, validator, ValidationError
|
||||
@ -22,7 +24,6 @@ from .base import (
|
||||
DownloadJobStatus,
|
||||
DownloadEventHandler,
|
||||
UnknownJobIDException,
|
||||
CancelledJobException,
|
||||
DownloadJobBase,
|
||||
)
|
||||
|
||||
@ -95,7 +96,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
variant: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
||||
) -> int:
|
||||
) -> DownloadJobBase:
|
||||
"""Create a download job and return its ID."""
|
||||
if re.match(r"^[\w-]+/[\w-]+$", source):
|
||||
cls = DownloadJobRepoID
|
||||
@ -119,8 +120,8 @@ class DownloadQueue(DownloadQueueBase):
|
||||
finally:
|
||||
self._lock.release()
|
||||
if start:
|
||||
self.start_job(id)
|
||||
return job.id
|
||||
self.start_job(job)
|
||||
return job
|
||||
|
||||
def release(self):
|
||||
"""Signal our threads to exit when queue done."""
|
||||
@ -136,31 +137,31 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""List all the jobs."""
|
||||
return self._jobs.values()
|
||||
|
||||
def change_priority(self, id: int, delta: int):
|
||||
def change_priority(self, job: DownloadJobBase, delta: int):
|
||||
"""Change the priority of a job. Smaller priorities run first."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
job = self._jobs[id]
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
job.priority += delta
|
||||
except KeyError as excp:
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def cancel_job(self, id: str):
|
||||
def cancel_job(self, job: DownloadJobBase):
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
If it is running it will be stopped.
|
||||
job.error will be set to CancelledJobException.
|
||||
job.status will be set to DownloadJobStatus.CANCELLED
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
job = self._jobs[id]
|
||||
job.status = DownloadJobStatus.ERROR
|
||||
job.error = CancelledJobException(f"Job {job.id} cancelled at caller's request")
|
||||
self._update_job_status
|
||||
del self._jobs[job.id]
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
# del self._jobs[job.id]
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
@ -171,16 +172,16 @@ class DownloadQueue(DownloadQueueBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def start_job(self, id: int):
|
||||
def start_job(self, job: DownloadJobBase):
|
||||
"""Enqueue (start) the indicated job."""
|
||||
try:
|
||||
job = self._jobs[id]
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.ENQUEUED)
|
||||
self._queue.put(job)
|
||||
except KeyError as excp:
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def pause_job(self, id: int):
|
||||
def pause_job(self, job: DownloadJobBase):
|
||||
"""
|
||||
Pause (dequeue) the indicated job.
|
||||
|
||||
@ -189,9 +190,9 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
job = self._jobs[id]
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.PAUSED)
|
||||
except KeyError as excp:
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
self._lock.release()
|
||||
@ -200,13 +201,13 @@ class DownloadQueue(DownloadQueueBase):
|
||||
"""Start (enqueue) all jobs that are idle or paused."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for id, job in self._jobs.items():
|
||||
for job in self._jobs.values():
|
||||
if job.status in [DownloadJobStatus.IDLE or DownloadJobStatus.PAUSED]:
|
||||
self.start_job(id)
|
||||
self.start_job(job)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def pause_all_jobs(self, id: int):
|
||||
def pause_all_jobs(self):
|
||||
"""Pause all running jobs."""
|
||||
try:
|
||||
self._lock.acquire()
|
||||
@ -221,11 +222,18 @@ class DownloadQueue(DownloadQueueBase):
|
||||
try:
|
||||
self._lock.acquire()
|
||||
for id, job in self._jobs.items():
|
||||
if job.status in [DownloadJobStatus.RUNNING, DownloadJobStatus.PAUSED, DownloadJobStatus.ENQUEUED]:
|
||||
if not self._in_terminal_state(job):
|
||||
self.cancel_job(id)
|
||||
finally:
|
||||
self._lock.release()
|
||||
|
||||
def _in_terminal_state(self, job: DownloadJobBase):
|
||||
return job.status in [
|
||||
DownloadJobStatus.COMPLETED,
|
||||
DownloadJobStatus.ERROR,
|
||||
DownloadJobStatus.CANCELLED,
|
||||
]
|
||||
|
||||
def _start_workers(self, max_workers: int):
|
||||
"""Start the requested number of worker threads."""
|
||||
for i in range(0, max_workers):
|
||||
@ -247,13 +255,21 @@ class DownloadQueue(DownloadQueueBase):
|
||||
|
||||
if job == STOP_JOB: # marker that queue is done
|
||||
break
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for cancelled or errored jobs
|
||||
|
||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||
if isinstance(job, DownloadJobURL):
|
||||
self._download_with_resume(job)
|
||||
elif isinstance(job, DownloadJobRepoID):
|
||||
self._download_repoid(job)
|
||||
else:
|
||||
raise NotImplementedError(f"Don't know what to do with this job: {job}")
|
||||
|
||||
if job.status == DownloadJobStatus.CANCELLED:
|
||||
self._cleanup_cancelled_job(job)
|
||||
|
||||
if self._in_terminal_state(job):
|
||||
del self._jobs[job.id]
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
def _download_with_resume(self, job: DownloadJobBase):
|
||||
@ -277,33 +293,34 @@ class DownloadQueue(DownloadQueueBase):
|
||||
dest = job.destination
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if dest.exists():
|
||||
job.bytes = dest.stat().st_size
|
||||
header["Range"] = f"bytes={job.bytes}-"
|
||||
open_mode = "ab"
|
||||
resp = self._requests.get(job.source, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
self._logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
self._logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
return
|
||||
|
||||
if resp.status_code == 206 or exist_size > 0:
|
||||
self._logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
elif resp.status_code != 200:
|
||||
self._logger.error(f"An error occurred during downloading {dest}: {resp.reason}")
|
||||
else:
|
||||
self._logger.info(f"{dest}: Downloading...")
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
try:
|
||||
if dest.exists():
|
||||
job.bytes = dest.stat().st_size
|
||||
header["Range"] = f"bytes={job.bytes}-"
|
||||
open_mode = "ab"
|
||||
resp = self._requests.get(job.source, headers=header, stream=True) # new request with range
|
||||
|
||||
if exist_size > content_length:
|
||||
self._logger.warning("corrupt existing file found. re-downloading")
|
||||
os.remove(dest)
|
||||
exist_size = 0
|
||||
|
||||
if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
|
||||
self._logger.warning(f"{dest}: complete file found. Skipping.")
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
return
|
||||
|
||||
if resp.status_code == 206 or exist_size > 0:
|
||||
self._logger.warning(f"{dest}: partial file found. Resuming...")
|
||||
elif resp.status_code != 200:
|
||||
raise HTTPError(resp.reason)
|
||||
else:
|
||||
self._logger.info(f"{dest}: Downloading...")
|
||||
|
||||
report_delta = job.total_bytes / 100 # report every 1% change
|
||||
last_report_bytes = 0
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
with open(dest, open_mode) as file:
|
||||
for data in resp.iter_content(chunk_size=16384):
|
||||
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
|
||||
@ -314,7 +331,6 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job)
|
||||
|
||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||
del self._jobs[job.id]
|
||||
except Exception as excp:
|
||||
self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}")
|
||||
job.error = excp
|
||||
@ -355,7 +371,10 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||
return
|
||||
|
||||
subqueue = self.__class__(event_handlers=[subdownload_event])
|
||||
subqueue = self.__class__(
|
||||
event_handlers=[subdownload_event],
|
||||
requests_session=self._requests,
|
||||
)
|
||||
try:
|
||||
repo_id = job.source
|
||||
variant = job.variant
|
||||
@ -431,3 +450,11 @@ class DownloadQueue(DownloadQueueBase):
|
||||
for v in basenames.values():
|
||||
result.add(v)
|
||||
return result
|
||||
|
||||
def _cleanup_cancelled_job(self, job: DownloadJobBase):
|
||||
self._logger.warning("Cleaning up leftover files from cancelled download job {job.destination}")
|
||||
dest = Path(job.destination)
|
||||
if dest.is_file():
|
||||
dest.unlink()
|
||||
elif dest.is_dir():
|
||||
shutil.rmtree(dest.as_posix(), ignore_errors=True)
|
||||
|
@ -1,40 +1,100 @@
|
||||
"""Test the queued download facility"""
|
||||
|
||||
import tempfile
|
||||
import time
|
||||
import requests
|
||||
from requests_testadapter import TestAdapter
|
||||
from requests import HTTPError
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.backend.model_manager.download import (
|
||||
DownloadJobStatus,
|
||||
DownloadQueue,
|
||||
DownloadJobBase,
|
||||
UnknownJobIDException,
|
||||
)
|
||||
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
SAFETENSORS1_CONTENT = b"I am a safetensors file (1)"
|
||||
SAFETENSORS1_HEADER = {
|
||||
"Content-Length": len(SAFETENSORS1_CONTENT),
|
||||
"Content-Disposition": 'filename="mock1.safetensors"',
|
||||
}
|
||||
SAFETENSORS2_CONTENT = b"I am a safetensors file (2)"
|
||||
SAFETENSORS2_HEADER = {
|
||||
"Content-Length": len(SAFETENSORS2_CONTENT),
|
||||
"Content-Disposition": 'filename="mock2.safetensors"',
|
||||
}
|
||||
INTERNET_AVAILABLE = requests.get("http://www.google.com/").status_code == 200
|
||||
|
||||
########################################################################################
|
||||
# Lots of dummy content here to test model download without using lots of bandwidth
|
||||
# The repo_id tests are not self-contained because they still need to use the HF API
|
||||
# to retrieve metainformation about the files to retrieve. However, the big weights files
|
||||
# are not downloaded.
|
||||
|
||||
# If the internet is not available, then the repo_id tests are skipped, but the single
|
||||
# URL tests are still run.
|
||||
|
||||
session = requests.Session()
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/12345", TestAdapter(SAFETENSORS1_CONTENT, status=200, headers=SAFETENSORS1_HEADER)
|
||||
)
|
||||
session.mount(
|
||||
"http://www.civitai.com/models/9999", TestAdapter(SAFETENSORS2_CONTENT, status=200, headers=SAFETENSORS2_HEADER)
|
||||
)
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
session.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
# mock HuggingFace URLs
|
||||
hf_sd2_paths = [
|
||||
"feature_extractor/preprocessor_config.json",
|
||||
"scheduler/scheduler_config.json",
|
||||
"text_encoder/config.json",
|
||||
"text_encoder/model.fp16.safetensors",
|
||||
"text_encoder/model.safetensors",
|
||||
"text_encoder/pytorch_model.fp16.bin",
|
||||
"text_encoder/pytorch_model.bin",
|
||||
"tokenizer/merges.txt",
|
||||
"tokenizer/special_tokens_map.json",
|
||||
"tokenizer/tokenizer_config.json",
|
||||
"tokenizer/vocab.json",
|
||||
"unet/config.json",
|
||||
"unet/diffusion_pytorch_model.fp16.safetensors",
|
||||
"unet/diffusion_pytorch_model.safetensors",
|
||||
"vae/config.json",
|
||||
"vae/diffusion_pytorch_model.fp16.safetensors",
|
||||
"vae/diffusion_pytorch_model.safetensors",
|
||||
]
|
||||
for path in hf_sd2_paths:
|
||||
url = f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/{path}"
|
||||
path = Path(path)
|
||||
filename = path.name
|
||||
content = b"This is the content for path " + bytearray(path.as_posix(), "utf-8")
|
||||
session.mount(
|
||||
url,
|
||||
TestAdapter(
|
||||
content,
|
||||
status=200,
|
||||
headers={"Content-Length": len(content), "Content-Disposition": f'filename="{filename}"'},
|
||||
),
|
||||
)
|
||||
|
||||
def test_basic_queue():
|
||||
# This is the content of `model_index.json` for stable-diffusion-2-1
|
||||
model_index_content = b'{"_class_name": "StableDiffusionPipeline", "_diffusers_version": "0.8.0", "feature_extractor": ["transformers", "CLIPImageProcessor"], "requires_safety_checker": false, "safety_checker": [null, null], "scheduler": ["diffusers", "DDIMScheduler"], "text_encoder": ["transformers", "CLIPTextModel"], "tokenizer": ["transformers", "CLIPTokenizer"], "unet": ["diffusers", "UNet2DConditionModel"], "vae": ["diffusers", "AutoencoderKL"]}'
|
||||
|
||||
session.mount(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/model_index.json",
|
||||
TestAdapter(
|
||||
model_index_content,
|
||||
status=200,
|
||||
headers={"Content-Length": len(model_index_content), "Content-Disposition": f'filename="model_index.json"'},
|
||||
),
|
||||
)
|
||||
|
||||
########################################################################################3
|
||||
|
||||
|
||||
def test_basic_queue_download():
|
||||
events = list()
|
||||
|
||||
def event_handler(job: DownloadJobBase):
|
||||
@ -45,20 +105,18 @@ def test_basic_queue():
|
||||
event_handlers=[event_handler],
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
assert isinstance(id1, int), "expected first job id to be numeric"
|
||||
|
||||
job = queue.id_to_job(id1)
|
||||
assert isinstance(job, DownloadJobBase), "expected job to be a DownloadJobBase"
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
assert isinstance(job, DownloadJobBase), "expected the job to be of type DownloadJobBase"
|
||||
assert isinstance(job.id, int), "expected the job id to be numeric"
|
||||
assert job.status == "idle", "expected job status to be idle"
|
||||
assert job.status == DownloadJobStatus.IDLE
|
||||
|
||||
queue.start_job(id1)
|
||||
queue.start_job(job)
|
||||
queue.join()
|
||||
assert events[0] == DownloadJobStatus.ENQUEUED
|
||||
assert events[-1] == DownloadJobStatus.COMPLETED
|
||||
assert DownloadJobStatus.RUNNING in events
|
||||
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
|
||||
|
||||
|
||||
def test_queue_priority():
|
||||
@ -67,31 +125,166 @@ def test_queue_priority():
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
queue.change_priority(id1, -10) # make id1 run first
|
||||
job1 = queue.id_to_job(id1)
|
||||
job2 = queue.id_to_job(id2)
|
||||
queue.change_priority(job1, -10) # make id1 run first
|
||||
assert job1 < job2
|
||||
|
||||
queue.start_all_jobs()
|
||||
queue.join()
|
||||
assert job1.job_sequence < job2.job_sequence
|
||||
|
||||
id1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
id2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
|
||||
queue.change_priority(id2, -10) # make id2 run first
|
||||
job1 = queue.id_to_job(id1)
|
||||
job2 = queue.id_to_job(id2)
|
||||
queue.change_priority(job2, -10) # make id2 run first
|
||||
assert job2 < job1
|
||||
|
||||
queue.start_all_jobs()
|
||||
queue.join()
|
||||
assert job2.job_sequence < job1.job_sequence
|
||||
|
||||
assert Path(tmpdir, "mock1.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock2.safetensors").exists(), f"expected {tmpdir}/mock1.safetensors to exist"
|
||||
assert Path(tmpdir, "mock12345.safetensors").exists(), f"expected {tmpdir}/mock12345.safetensors to exist"
|
||||
assert Path(tmpdir, "mock9999.safetensors").exists(), f"expected {tmpdir}/mock9999.safetensors to exist"
|
||||
|
||||
|
||||
def test_repo_id_download():
|
||||
if not INTERNET_AVAILABLE:
|
||||
return
|
||||
repo_id = "stabilityai/stable-diffusion-2-1"
|
||||
queue = DownloadQueue(
|
||||
requests_session=session,
|
||||
)
|
||||
|
||||
# first with fp16 variant
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
queue.create_download_job(source=repo_id, destdir=tmpdir, variant="fp16", start=True)
|
||||
queue.join()
|
||||
repo_root = Path(tmpdir, "stable-diffusion-2-1")
|
||||
assert repo_root.exists()
|
||||
assert Path(repo_root, "model_index.json").exists()
|
||||
assert Path(repo_root, "text_encoder", "config.json").exists()
|
||||
assert Path(repo_root, "text_encoder", "model.fp16.safetensors").exists()
|
||||
|
||||
# then without fp16
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
queue.create_download_job(source=repo_id, destdir=tmpdir, start=True)
|
||||
queue.join()
|
||||
repo_root = Path(tmpdir, "stable-diffusion-2-1")
|
||||
assert Path(repo_root, "text_encoder", "model.safetensors").exists()
|
||||
assert not Path(repo_root, "text_encoder", "model.fp16.safetensors").exists()
|
||||
|
||||
|
||||
def test_failure_modes():
|
||||
queue = DownloadQueue(
|
||||
requests_session=session,
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
job = queue.create_download_job(source="http://www.civitai.com/models/broken", destdir=tmpdir)
|
||||
queue.join()
|
||||
assert job.status == "error"
|
||||
assert isinstance(job.error, HTTPError)
|
||||
assert str(job.error) == "NOT FOUND"
|
||||
|
||||
# create a foreign job which will be invalid for the queue
|
||||
bad_job = DownloadJobBase(id=999, source="mock", destination="mock")
|
||||
try:
|
||||
queue.start_job(bad_job) # this should fail
|
||||
succeeded = True
|
||||
except UnknownJobIDException:
|
||||
succeeded = False
|
||||
assert not succeeded
|
||||
|
||||
|
||||
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
|
||||
def event_handler(job: DownloadJobBase):
|
||||
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
|
||||
|
||||
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||
job3 = queue.create_download_job(source="http://www.civitai.com/models/54321", destdir=tmpdir, start=False)
|
||||
|
||||
assert job1.status == "idle"
|
||||
queue.start_job(job1)
|
||||
queue.start_job(job3)
|
||||
time.sleep(0.1) # wait for enqueueing
|
||||
assert job1.status in ["enqueued", "running"]
|
||||
|
||||
# check pause and restart
|
||||
queue.pause_job(job1)
|
||||
time.sleep(0.1) # wait to be paused
|
||||
assert job1.status == "paused"
|
||||
|
||||
queue.start_job(job1)
|
||||
time.sleep(0.1)
|
||||
assert job1.status == "running"
|
||||
|
||||
# check cancel
|
||||
queue.start_job(job2)
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "running"
|
||||
queue.cancel_job(job2)
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "cancelled"
|
||||
|
||||
queue.join()
|
||||
assert job1.status == "completed"
|
||||
assert job2.status == "cancelled"
|
||||
assert job3.status == "completed"
|
||||
|
||||
assert Path(tmpdir, "mock12345.safetensors").exists()
|
||||
assert Path(tmpdir, "mock9999.safetensors").exists() is False, "cancelled file should be deleted"
|
||||
assert Path(tmpdir, "mock54321.safetensors").exists()
|
||||
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
||||
def test_pause_cancel_repo_id(): # this one is tricky because of potential race conditions
|
||||
def event_handler(job: DownloadJobBase):
|
||||
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
|
||||
|
||||
if not INTERNET_AVAILABLE:
|
||||
return
|
||||
|
||||
repo_id = "stabilityai/stable-diffusion-2-1"
|
||||
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
|
||||
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
|
||||
job2 = queue.create_download_job(source=repo_id, destdir=tmpdir2, variant="fp16", start=False)
|
||||
assert job1.status == "idle"
|
||||
queue.start_job(job1)
|
||||
time.sleep(0.1) # wait for enqueueing
|
||||
assert job1.status in ["enqueued", "running"]
|
||||
|
||||
# check pause and restart
|
||||
queue.pause_job(job1)
|
||||
time.sleep(0.1) # wait to be paused
|
||||
assert job1.status == "paused"
|
||||
|
||||
queue.start_job(job1)
|
||||
time.sleep(0.1)
|
||||
assert job1.status == "running"
|
||||
|
||||
# check cancel
|
||||
queue.start_job(job2)
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "running"
|
||||
queue.cancel_job(job2)
|
||||
time.sleep(0.1)
|
||||
assert job2.status == "cancelled"
|
||||
|
||||
queue.join()
|
||||
assert job1.status == "completed"
|
||||
assert job2.status == "cancelled"
|
||||
|
||||
assert Path(tmpdir1, "stable-diffusion-2-1", "model_index.json").exists()
|
||||
assert not Path(
|
||||
tmpdir2, "stable-diffusion-2-1", "model_index.json"
|
||||
).exists(), "cancelled file should be deleted"
|
||||
|
||||
assert len(queue.list_jobs()) == 0
|
||||
|
Reference in New Issue
Block a user