mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
6cc6a45274
Just a bit of typo protection in lieu of full type safety for these methods, which is difficult due to the typing of `DownloadEventHandler`.
599 lines
24 KiB
Python
599 lines
24 KiB
Python
# Copyright (c) 2023, Lincoln D. Stein
|
|
"""Implementation of multithreaded download queue for invokeai."""
|
|
|
|
import os
|
|
import re
|
|
import threading
|
|
import time
|
|
import traceback
|
|
from pathlib import Path
|
|
from queue import Empty, PriorityQueue
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set
|
|
|
|
import requests
|
|
from pydantic.networks import AnyHttpUrl
|
|
from requests import HTTPError
|
|
from tqdm import tqdm
|
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
|
from invokeai.app.services.events.events_base import EventServiceBase
|
|
from invokeai.app.util.misc import get_iso_timestamp
|
|
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
|
from invokeai.backend.util.logging import InvokeAILogger
|
|
|
|
from .download_base import (
|
|
DownloadEventHandler,
|
|
DownloadExceptionHandler,
|
|
DownloadJob,
|
|
DownloadJobBase,
|
|
DownloadJobCancelledException,
|
|
DownloadJobStatus,
|
|
DownloadQueueServiceBase,
|
|
MultiFileDownloadJob,
|
|
ServiceInactiveException,
|
|
UnknownJobIDException,
|
|
)
|
|
|
|
if TYPE_CHECKING:
    from invokeai.app.services.events.events_base import EventServiceBase

# Chunk size, in bytes, requested from requests' iter_content() on each read of a streaming download.
DOWNLOAD_CHUNK_SIZE = 100_000
|
|
|
|
|
|
class DownloadQueueService(DownloadQueueServiceBase):
    """Class for queued download of models."""

    def __init__(
        self,
        max_parallel_dl: int = 5,
        app_config: Optional[InvokeAIAppConfig] = None,
        event_bus: Optional["EventServiceBase"] = None,
        requests_session: Optional[requests.sessions.Session] = None,
    ):
        """
        Initialize DownloadQueue.

        :param max_parallel_dl: Number of simultaneous downloads allowed [5].
        :param app_config: InvokeAIAppConfig object; defaults to the result of get_config().
        :param event_bus: Optional event service; when given, download lifecycle events are emitted on it.
        :param requests_session: Optional requests.sessions.Session object, for unit tests.
        """
        self._app_config = app_config or get_config()
        self._jobs: Dict[int, DownloadJob] = {}
        # Maps the source URL of each multifile sub-download to the multifile job that owns it.
        self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
        self._next_job_id = 0
        self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
        self._stop_event = threading.Event()
        self._job_terminated_event = threading.Event()
        self._worker_pool: Set[threading.Thread] = set()
        self._lock = threading.Lock()
        self._logger = InvokeAILogger.get_logger("DownloadQueueService")
        self._event_bus = event_bus
        self._requests = requests_session or requests.Session()
        self._accept_download_requests = False
        self._max_parallel_dl = max_parallel_dl

    def start(self, *args: Any, **kwargs: Any) -> None:
        """Start the download worker threads.

        :raises Exception: if the service has already been started.
        """
        with self._lock:
            if self._worker_pool:
                raise Exception("Attempt to start the download service twice")
            self._stop_event.clear()
            self._start_workers(self._max_parallel_dl)
            self._accept_download_requests = True

    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Stop the download worker threads.

        New requests are rejected, jobs still waiting in the queue are discarded, every
        non-terminal job is cancelled, and the call blocks until all workers have exited.

        :raises Exception: if the service was not started.
        """
        with self._lock:
            if not self._worker_pool:
                raise Exception("Attempt to stop the download service before it was started")
            self._accept_download_requests = False  # reject attempts to add new jobs to queue
            jobs = self.list_jobs()
            queued_jobs = [x for x in jobs if x.status == DownloadJobStatus.WAITING]
            active_jobs = [x for x in jobs if x.status == DownloadJobStatus.RUNNING]
            if queued_jobs:
                self._logger.warning(f"Cancelling {len(queued_jobs)} queued downloads")
            if active_jobs:
                self._logger.info(f"Waiting for {len(active_jobs)} active download jobs to complete")
            self._drain_queue()
            self.cancel_all_jobs()
            self._stop_event.set()
            workers = list(self._worker_pool)

        # Join OUTSIDE the lock: a worker that is finishing or cancelling a multifile
        # sub-download runs the _mfd_* callbacks, which acquire self._lock. Joining
        # while holding the lock would deadlock.
        for thread in workers:
            thread.join()

        with self._lock:
            self._worker_pool.clear()

    def _drain_queue(self) -> None:
        """Discard every job still waiting in the priority queue.

        Each removed item is acknowledged with task_done() so that join() does not block
        forever on work that will never be processed.
        """
        while True:
            try:
                self._queue.get_nowait()
            except Empty:
                return
            self._queue.task_done()

    def submit_download_job(
        self,
        job: DownloadJob,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> None:
        """Enqueue a download job.

        Assigns the job a fresh id, attaches the given callbacks and places it on the
        priority queue.

        :raises ServiceInactiveException: if start() has not been called (or stop() has).
        """
        if not self._accept_download_requests:
            raise ServiceInactiveException(
                "The download service is not currently accepting requests. Please call start() to initialize the service."
            )
        job.id = self._next_id()
        job.set_callbacks(
            on_start=on_start,
            on_progress=on_progress,
            on_complete=on_complete,
            on_cancelled=on_cancelled,
            on_error=on_error,
        )
        self._jobs[job.id] = job
        self._queue.put(job)

    def download(
        self,
        source: AnyHttpUrl,
        dest: Path,
        priority: int = 10,
        access_token: Optional[str] = None,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> DownloadJob:
        """Create and enqueue a download job and return it.

        When access_token is not given, a token is looked up from the app config's
        remote_api_tokens by matching the source URL.

        :raises ServiceInactiveException: if the service is not accepting requests.
        """
        if not self._accept_download_requests:
            raise ServiceInactiveException(
                "The download service is not currently accepting requests. Please call start() to initialize the service."
            )
        job = DownloadJob(
            source=source,
            dest=dest,
            priority=priority,
            access_token=access_token or self._lookup_access_token(source),
        )
        self.submit_download_job(
            job,
            on_start=on_start,
            on_progress=on_progress,
            on_complete=on_complete,
            on_cancelled=on_cancelled,
            on_error=on_error,
        )
        return job

    def multifile_download(
        self,
        parts: List[RemoteModelFile],
        dest: Path,
        access_token: Optional[str] = None,
        submit_job: bool = True,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> MultiFileDownloadJob:
        """Create a job that downloads several remote files into a common destination.

        Each part becomes a DownloadJob writing to ``dest / part.path``; the returned
        MultiFileDownloadJob aggregates their progress and fires its own callbacks.

        :param parts: Remote files to fetch; each ``part.path`` must be relative to ``dest``.
        :param dest: Destination directory.
        :param access_token: Bearer token for all parts. When omitted, a token is looked
            up per part URL from the app config, as download() does.
        :param submit_job: If True (default) enqueue the parts now; otherwise call
            submit_multifile_download() later.
        :returns: The MultiFileDownloadJob.
        """
        mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
        mfdj.set_callbacks(
            on_start=on_start,
            on_progress=on_progress,
            on_complete=on_complete,
            on_cancelled=on_cancelled,
            on_error=on_error,
        )

        for part in parts:
            url = part.url
            rel_path = Path(part.path)
            path = dest / rel_path
            # is_relative_to() is purely lexical, so ".." components would slip past it on their own.
            assert path.is_relative_to(dest) and ".." not in rel_path.parts, "only relative download paths accepted"
            job = DownloadJob(
                source=url,
                dest=path,
                access_token=access_token or self._lookup_access_token(url),
            )
            mfdj.download_parts.add(job)
            self._download_part2parent[job.source] = mfdj
        if submit_job:
            self.submit_multifile_download(mfdj)
        return mfdj

    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
        """Enqueue every part of a multifile job, wiring each to the internal _mfd_* callbacks."""
        for download_job in job.download_parts:
            self.submit_download_job(
                download_job,
                on_start=self._mfd_started,
                on_progress=self._mfd_progress,
                on_complete=self._mfd_complete,
                on_cancelled=self._mfd_cancelled,
                on_error=self._mfd_error,
            )

    def join(self) -> None:
        """Wait for all jobs to complete."""
        self._queue.join()

    def _next_id(self) -> int:
        """Return a new, unique job id."""
        with self._lock:
            job_id = self._next_job_id
            self._next_job_id += 1
        return job_id

    def list_jobs(self) -> List[DownloadJob]:
        """List all the jobs."""
        return list(self._jobs.values())

    def prune_jobs(self) -> None:
        """Prune completed and errored queue items from the job list."""
        with self._lock:
            finished = [job_id for job_id, job in self._jobs.items() if job.in_terminal_state]
            for job_id in finished:
                del self._jobs[job_id]

    def id_to_job(self, id: int) -> DownloadJob:
        """Translate a job ID into a DownloadJob object.

        :raises UnknownJobIDException: if no job has this id.
        """
        try:
            return self._jobs[id]
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp

    def cancel_job(self, job: DownloadJobBase) -> None:
        """
        Cancel the indicated job.

        Only jobs that are WAITING or RUNNING are affected. A running job stops at its
        next downloaded chunk, and a waiting job is skipped when a worker picks it up.
        The worker then sets job.status to DownloadJobStatus.CANCELLED.
        """
        if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
            job.cancel()

    def cancel_all_jobs(self) -> None:
        """Cancel every job that has not yet reached a terminal state."""
        for job in self._jobs.values():
            if not job.in_terminal_state:
                self.cancel_job(job)

    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
        """Block until the indicated job has reached terminal state, or when timeout limit reached.

        :param timeout: Seconds to wait; 0 (default) waits indefinitely.
        :raises TimeoutError: if the timeout elapses first.
        """
        start = time.time()
        while not job.in_terminal_state:
            if self._job_terminated_event.wait(timeout=0.25):  # in case we miss an event
                self._job_terminated_event.clear()
            if timeout > 0 and time.time() - start > timeout:
                raise TimeoutError("Timeout exceeded")
        return job

    def _start_workers(self, max_workers: int) -> None:
        """Start the requested number of worker threads."""
        self._stop_event.clear()
        for _ in range(max_workers):
            worker = threading.Thread(target=self._download_next_item, daemon=True)
            self._logger.debug(f"Download queue worker thread {worker.name} starting.")
            worker.start()
            self._worker_pool.add(worker)

    def _download_next_item(self) -> None:
        """Worker thread loop: process jobs from the priority queue until the stop event is set."""
        while not self._stop_event.is_set():
            try:
                job = self._queue.get(timeout=1)
            except Empty:
                continue
            try:
                job.job_started = get_iso_timestamp()
                self._do_download(job)
                self._signal_job_complete(job)
            except DownloadJobCancelledException:
                self._signal_job_cancelled(job)
                self._cleanup_cancelled_job(job)
            except Exception as excp:
                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
                job.error = traceback.format_exc()
                self._signal_job_error(job, excp)
            finally:
                job.job_ended = get_iso_timestamp()
                self._download_part2parent.pop(job.source, None)  # if this is a subpart of a multipart job, remove it
                self._job_terminated_event.set()  # signal a change to terminal state
                self._queue.task_done()

        self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")

    def _do_download(self, job: DownloadJob) -> None:
        """Do the actual download.

        :raises DownloadJobCancelledException: if the job was cancelled before or during transfer.
        :raises HTTPError: if the server does not answer with a successful status.
        :raises OSError: if the destination file already exists.
        """
        # A job cancelled while it was still waiting in the queue must not hit the network.
        if job.cancelled:
            raise DownloadJobCancelledException("Job was cancelled at caller's request")

        url = job.source
        header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
        open_mode = "wb"

        # Make a streaming request. This will retrieve headers including
        # content-length and content-disposition, but not fetch any content itself.
        # The context manager releases the connection on every exit path
        # (existing file, cancellation, error), not only after a full read.
        with self._requests.get(str(url), headers=header, stream=True) as resp:
            if not resp.ok:
                raise HTTPError(resp.reason)

            job.content_type = resp.headers.get("Content-Type")
            content_length = int(resp.headers.get("content-length", 0))
            job.total_bytes = content_length

            if job.dest.is_dir():
                file_name = os.path.basename(str(url.path))  # default is to use the last bit of the URL

                if match := re.search('filename="(.+)"', resp.headers.get("Content-Disposition", "")):
                    remote_name = match.group(1)
                    if self._validate_filename(job.dest.as_posix(), remote_name):
                        file_name = remote_name

                job.download_path = job.dest / file_name

            else:
                job.dest.parent.mkdir(parents=True, exist_ok=True)
                job.download_path = job.dest

            assert job.download_path

            # Don't clobber an existing file. See commit 82c2c85202f88c6d24ff84710f297cfc6ae174af
            # for code that instead resumes an interrupted download.
            if job.download_path.exists():
                raise OSError(f"[Errno 17] File {job.download_path} exists")

            # append ".downloading" to the path
            in_progress_path = self._in_progress_path(job.download_path)

            # signal caller that the download is starting. At this point, key fields such as
            # download_path and total_bytes will be populated. We call it here because we might
            # discover that the local file is already complete and generate a COMPLETED status.
            self._signal_job_started(job)

            # "range not satisfiable" - local file is at least as large as the remote file
            if resp.status_code == 416 or (content_length > 0 and job.bytes >= content_length):
                self._logger.warning(f"{job.download_path}: complete file found. Skipping.")
                return

            # "partial content" - local file is smaller than remote file
            elif resp.status_code == 206 or job.bytes > 0:
                self._logger.warning(f"{job.download_path}: partial file found. Resuming")

            # some other error
            elif resp.status_code != 200:
                raise HTTPError(resp.reason)

            self._logger.debug(f"{job.source}: Downloading {job.download_path}")
            report_delta = job.total_bytes / 100  # report every 1% change
            last_report_bytes = 0

            # DOWNLOAD LOOP
            with open(in_progress_path, open_mode) as file:
                for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    if job.cancelled:
                        raise DownloadJobCancelledException("Job was cancelled at caller's request")

                    job.bytes += file.write(data)
                    if (job.bytes - last_report_bytes >= report_delta) or (job.bytes >= job.total_bytes):
                        last_report_bytes = job.bytes
                        self._signal_job_progress(job)

        # if we get here we are done and can rename the file to the original dest
        self._logger.debug(f"{job.source}: saved to {job.download_path} (bytes={job.bytes})")
        in_progress_path.rename(job.download_path)

    def _validate_filename(self, directory: str, filename: str) -> bool:
        """Return True if a server-supplied filename is safe to create directly inside directory.

        Rejects names containing any path separator for this platform (so the file cannot
        escape ``directory``), names starting with "..", and names that exceed the
        filesystem's component or total path length limits.
        """
        pc_name_max = get_pc_name_max(directory)
        pc_path_max = get_pc_path_max(directory)
        separators = {"/", os.sep}
        if os.altsep:
            separators.add(os.altsep)
        if any(sep in filename for sep in separators):
            return False
        if filename.startswith(".."):
            return False
        if len(filename) > pc_name_max:
            return False
        if len(os.path.join(directory, filename)) > pc_path_max:
            return False
        return True

    def _in_progress_path(self, path: Path) -> Path:
        """Return the temporary path (``<name>.downloading``) used while path is being written."""
        return path.with_name(path.name + ".downloading")

    def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
        """Return the first configured remote API token whose url_regex matches source, else None."""
        for pair in self._app_config.remote_api_tokens or []:
            if re.search(pair.url_regex, str(source)):
                return pair.token
        return None

    def _signal_job_started(self, job: DownloadJob) -> None:
        job.status = DownloadJobStatus.RUNNING
        self._execute_cb(job, "on_start")
        if self._event_bus:
            self._event_bus.emit_download_started(job)

    def _signal_job_progress(self, job: DownloadJob) -> None:
        self._execute_cb(job, "on_progress")
        if self._event_bus:
            self._event_bus.emit_download_progress(job)

    def _signal_job_complete(self, job: DownloadJob) -> None:
        job.status = DownloadJobStatus.COMPLETED
        self._execute_cb(job, "on_complete")
        if self._event_bus:
            self._event_bus.emit_download_complete(job)

    def _signal_job_cancelled(self, job: DownloadJob) -> None:
        """Mark a waiting/running job CANCELLED, fire its callbacks, and cancel its multifile parent if any."""
        if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
            return

        job.status = DownloadJobStatus.CANCELLED
        self._execute_cb(job, "on_cancelled")
        if self._event_bus:
            self._event_bus.emit_download_cancelled(job)

        # if multifile download, then signal the parent
        if parent_job := self._download_part2parent.get(job.source, None):
            if not parent_job.in_terminal_state:
                parent_job.status = DownloadJobStatus.CANCELLED
                self._execute_cb(parent_job, "on_cancelled")

    def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
        job.status = DownloadJobStatus.ERROR
        # Join the formatted lines; logging the list itself would print escaped newlines.
        self._logger.error(f"{str(job.source)}: {''.join(traceback.format_exception(excp))}")
        self._execute_cb(job, "on_error", excp)

        if self._event_bus:
            self._event_bus.emit_download_error(job)

    def _cleanup_cancelled_job(self, job: DownloadJob) -> None:
        """Remove the partial ``.downloading`` file left behind by a cancelled job, if any."""
        self._logger.debug(f"Cleaning up leftover files from cancelled download job {job.download_path}")
        try:
            if job.download_path:
                partial_file = self._in_progress_path(job.download_path)
                # The job may be cancelled before any data was written; that is not an error.
                partial_file.unlink(missing_ok=True)
        except OSError as excp:
            self._logger.warning(excp)

    ########################################
    # callbacks used for multifile downloads
    ########################################
    def _mfd_started(self, download_job: DownloadJob) -> None:
        self._logger.info(f"File download started: {download_job.source}")
        with self._lock:
            mf_job = self._download_part2parent[download_job.source]
            if mf_job.waiting:
                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
                mf_job.status = DownloadJobStatus.RUNNING
                assert download_job.download_path is not None
                path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
                mf_job.download_path = (
                    mf_job.dest / path_relative_to_destdir.parts[0]
                )  # keep just the first component of the path
                self._execute_cb(mf_job, "on_start")

    def _mfd_progress(self, download_job: DownloadJob) -> None:
        with self._lock:
            mf_job = self._download_part2parent[download_job.source]
            if mf_job.cancelled:
                for part in mf_job.download_parts:
                    self.cancel_job(part)
            elif mf_job.running:
                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
                # Bytes actually received so far across all parts (not their expected totals).
                mf_job.bytes = sum(x.bytes for x in mf_job.download_parts)
                self._execute_cb(mf_job, "on_progress")

    def _mfd_complete(self, download_job: DownloadJob) -> None:
        self._logger.info(f"Download complete: {download_job.source}")
        with self._lock:
            mf_job = self._download_part2parent[download_job.source]

            # are there any more active jobs left in this task?
            if mf_job.running and all(x.complete for x in mf_job.download_parts):
                mf_job.status = DownloadJobStatus.COMPLETED
                self._execute_cb(mf_job, "on_complete")

            # we're done with this sub-job
            self._job_terminated_event.set()

    def _mfd_cancelled(self, download_job: DownloadJob) -> None:
        with self._lock:
            mf_job = self._download_part2parent[download_job.source]
            assert mf_job is not None

            if not mf_job.in_terminal_state:
                self._logger.warning(f"Download cancelled: {download_job.source}")
                mf_job.cancel()

            for s in mf_job.download_parts:
                self.cancel_job(s)

    def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
        with self._lock:
            mf_job = self._download_part2parent[download_job.source]
            assert mf_job is not None
            if not mf_job.in_terminal_state:
                mf_job.status = download_job.status
                mf_job.error = download_job.error
                mf_job.error_type = download_job.error_type
                self._execute_cb(mf_job, "on_error", excp)
                self._logger.error(
                    f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
                )
                # Cancel every unfinished sibling, including ones still waiting in the queue,
                # which would otherwise be downloaded for nothing. cancel_job() ignores
                # parts that are already in a terminal state.
                for s in mf_job.download_parts:
                    self.cancel_job(s)
            self._download_part2parent.pop(download_job.source)
            self._job_terminated_event.set()

    def _execute_cb(
        self,
        job: DownloadJob | MultiFileDownloadJob,
        callback_name: Literal[
            "on_start",
            "on_progress",
            "on_complete",
            "on_cancelled",
            "on_error",
        ],
        excp: Optional[Exception] = None,
    ) -> None:
        """Invoke job.<callback_name>(job[, excp]) if set, logging (not raising) any error it throws."""
        if callback := getattr(job, callback_name, None):
            args = [job, excp] if excp is not None else [job]
            try:
                callback(*args)
            except Exception as e:
                self._logger.error(
                    f"An error occurred while processing the {callback_name} callback: {''.join(traceback.format_exception(e))}"
                )
|
|
|
|
|
|
def get_pc_name_max(directory: str) -> int:
    """Return the maximum filename (single path component) length for the filesystem holding ``directory``.

    Uses ``os.pathconf(directory, "PC_NAME_MAX")`` where available. Falls back to 255
    (the NTFS component limit, also used by exFAT/FAT32 long names) when pathconf is
    absent (Windows) or fails, or when it reports no definite limit. It can fail with
    OSError (e.g. macOS external drives, missing directory) or with ValueError (name
    unknown on this platform); "no definite limit" shows up as a non-positive value.
    """
    if hasattr(os, "pathconf"):
        try:
            limit = os.pathconf(directory, "PC_NAME_MAX")
        except (OSError, ValueError):
            # macOS w/ external drives raise OSError; ValueError if the name is unsupported
            pass
        else:
            # -1 means "no definite limit"; using it would make every filename look too long.
            if limit > 0:
                return limit
    return 255  # hardcoded for windows: NTFS max filename length (260 is MAX_PATH, a full-path limit)
|
|
|
|
|
|
def get_pc_path_max(directory: str) -> int:
|
|
if hasattr(os, "pathconf"):
|
|
try:
|
|
return os.pathconf(directory, "PC_PATH_MAX")
|
|
except OSError:
|
|
# some platforms may not have this value
|
|
pass
|
|
return 32767 # hardcoded for windows with long names enabled
|
|
|
|
|
|
# Example on_progress handler that draws one TQDM progress bar per download job.
# Activate with:
# download_service.download(AnyHttpUrl("http://foo.bar/baz"), Path("/tmp"), on_progress=TqdmProgress().update)
class TqdmProgress:
    """TQDM-based progress bars, one per job id; pass ``update`` as an on_progress handler."""

    _bars: Dict[int, tqdm]  # type: ignore
    _last: Dict[int, int]  # bytes already shown on each job's bar, keyed by job id

    def __init__(self) -> None:  # noqa D107
        self._bars, self._last = {}, {}

    def update(self, job: DownloadJob) -> None:  # noqa D102
        key = job.id
        bar = self._bars.get(key)
        if bar is None:
            # First progress report for this job: create its bar, labelled with the file name.
            assert job.download_path
            bar = tqdm(
                desc=Path(job.download_path).name,
                initial=0,
                total=job.total_bytes,
                unit="iB",
                unit_scale=True,
            )
            self._bars[key] = bar
            self._last[key] = 0
        # tqdm.update() takes an increment, so advance by the bytes received since the last report.
        bar.update(job.bytes - self._last[key])
        self._last[key] = job.bytes
|