# Copyright (c) 2023, Lincoln D. Stein
"""Implementation of multithreaded download queue for invokeai."""

import os
import re
import shutil
import threading
import time
import traceback
from pathlib import Path
from queue import PriorityQueue
from typing import Callable, Dict, List, Optional, Set, Union

import requests
from pydantic import Field
from pydantic.networks import AnyHttpUrl
from requests import HTTPError

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, Logger

from ..storage import DuplicateModelException
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException

# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000

# Sentinel job marking that the queue is done and the worker thread should exit
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")

# Regular expression for recognizing an http(s) URL source
HTTP_RE = r"^https?://"


class DownloadJobPath(DownloadJobBase):
    """Download from a local Path."""

    source: Path = Field(description="Local filesystem Path where the model can be found")


class DownloadJobRemoteSource(DownloadJobBase):
    """A DownloadJob from a remote source that provides progress info."""

    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total bytes to download")
    access_token: Optional[str] = Field(default=None, description="Access token needed to access this resource")


class DownloadJobURL(DownloadJobRemoteSource):
    """Job declaration for downloading individual URLs."""

    source: AnyHttpUrl = Field(description="URL to download")


class DownloadQueue(DownloadQueueBase):
    """Class for queued download of models."""

    _jobs: Dict[int, DownloadJobBase]
    _worker_pool: Set[threading.Thread]
    _queue: PriorityQueue
    _lock: threading.RLock
    _logger: Logger
    _event_handlers: List[DownloadEventHandler] = Field(default_factory=list)
    _next_job_id: int = 0
    _sequence: int = 0  # for debugging: tags jobs in dequeuing order
    _requests: requests.sessions.Session
    _quiet: bool = False

    def __init__(
        self,
        max_parallel_dl: int = 5,
        event_handlers: Optional[List[DownloadEventHandler]] = None,
        requests_session: Optional[requests.sessions.Session] = None,
        config: Optional[InvokeAIAppConfig] = None,
        quiet: bool = False,
    ):
        """
        Initialize DownloadQueue.

        :param max_parallel_dl: Number of simultaneous downloads allowed [5].
        :param event_handlers: Optional callables that will be called each time a job status changes.
        :param requests_session: Optional requests.sessions.Session object, for unit tests.
        :param config: Optional InvokeAIAppConfig used to configure the logger.
        :param quiet: If True, suppress per-download log messages.
        """
        self._jobs = dict()
        self._next_job_id = 0
        self._queue = PriorityQueue()
        self._worker_pool = set()
        self._lock = threading.RLock()
        self._logger = InvokeAILogger.get_logger(config=config)
        # Avoid a mutable default argument; fall back to an empty handler list.
        self._event_handlers = event_handlers or []
        self._requests = requests_session or requests.Session()
        self._quiet = quiet

        self._start_workers(max_parallel_dl)
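
    # A hedged sketch of an event handler (illustrative only; `progress_handler`
    # is a hypothetical callback, not part of this module). Any callable taking
    # the job as its sole argument can subscribe to status changes:
    #
    #   def progress_handler(job: DownloadJobBase) -> None:
    #       if isinstance(job, DownloadJobRemoteSource) and job.total_bytes:
    #           print(f"{job.source}: {job.bytes}/{job.total_bytes} bytes")
    #
    #   queue = DownloadQueue(event_handlers=[progress_handler])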

    def create_download_job(
        self,
        source: Union[str, Path, AnyHttpUrl],
        destdir: Path,
        start: bool = True,
        priority: int = 10,
        filename: Optional[Path] = None,
        variant: Optional[str] = None,
        access_token: Optional[str] = None,
        event_handlers: Optional[List[DownloadEventHandler]] = None,
    ) -> DownloadJobBase:
        """Create a download job and return it."""
        kwargs: Dict[str, Optional[str]] = dict()

        # Dispatch on the source type: local paths are moved into place,
        # http/https URLs are downloaded with resume support.
        cls = DownloadJobBase
        if Path(source).exists():
            cls = DownloadJobPath
        elif re.match(HTTP_RE, str(source)):
            cls = DownloadJobURL
            kwargs.update(access_token=access_token)
        else:
            raise NotImplementedError(f"Don't know what to do with this type of source: {source}")

        # NOTE: `variant` is unused here; presumably it is consumed by subclasses.
        job = cls(
            source=source,
            destination=Path(destdir) / (filename or "."),
            event_handlers=event_handlers or [],
            priority=priority,
            **kwargs,
        )

        return self.submit_download_job(job, start)

    def submit_download_job(
        self,
        job: DownloadJobBase,
        start: bool = True,
    ) -> DownloadJobBase:
        """Submit a job, assigning it a unique ID, and optionally start it."""
        # Add the queue's own handlers to the job.
        for handler in self._event_handlers:
            job.add_event_handler(handler)
        try:
            self._lock.acquire()
            job.id = self._next_job_id
            self._jobs[job.id] = job
            self._next_job_id += 1
        finally:
            self._lock.release()
        if start:
            self.start_job(job)
        return job

    def release(self):
        """Signal worker threads to exit once the queue is done."""
        for thread in self._worker_pool:
            if thread.is_alive():
                self._queue.put(STOP_JOB)

    def join(self):
        """Wait for all jobs to complete."""
        self._queue.join()

    def list_jobs(self) -> List[DownloadJobBase]:
        """List all the jobs."""
        return list(self._jobs.values())

    def change_priority(self, job: DownloadJobBase, delta: int):
        """
        Change the priority of a job. Smaller priorities run first.

        NOTE: this adjusts the job's priority attribute only; a job already
        sitting in the PriorityQueue keeps the ordering it had when enqueued.
        """
        try:
            self._lock.acquire()
            assert isinstance(self._jobs[job.id], DownloadJobBase)
            job.priority += delta
        except (AssertionError, KeyError) as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
        finally:
            self._lock.release()

    def prune_jobs(self):
        """Prune completed and errored queue items from the job list."""
        try:
            to_delete = set()
            self._lock.acquire()
            for job_id, job in self._jobs.items():
                if self._in_terminal_state(job):
                    to_delete.add(job_id)
            for job_id in to_delete:
                del self._jobs[job_id]
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
        finally:
            self._lock.release()

    def cancel_job(self, job: DownloadJobBase, preserve_partial: bool = False):
        """
        Cancel the indicated job.

        If it is running it will be stopped, and its status set to
        DownloadJobStatus.CANCELLED.
        """
        try:
            self._lock.acquire()
            assert isinstance(self._jobs[job.id], DownloadJobBase)
            job.preserve_partial_downloads = preserve_partial
            self._update_job_status(job, DownloadJobStatus.CANCELLED)
            job.cleanup()
        except (AssertionError, KeyError) as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
        finally:
            self._lock.release()

    def id_to_job(self, id: int) -> DownloadJobBase:
        """Translate a job ID into a DownloadJobBase object."""
        try:
            return self._jobs[id]
        except KeyError as excp:
            raise UnknownJobIDException("Unrecognized job") from excp

    def start_job(self, job: DownloadJobBase):
        """Enqueue (start) the indicated job."""
        try:
            assert isinstance(self._jobs[job.id], DownloadJobBase)
            self._update_job_status(job, DownloadJobStatus.ENQUEUED)
            self._queue.put(job)
        except (AssertionError, KeyError) as excp:
            raise UnknownJobIDException("Unrecognized job") from excp

    def pause_job(self, job: DownloadJobBase):
        """
        Pause (dequeue) the indicated job.

        The job can be restarted with start_job() and the download will pick up
        from where it left off.
        """
        try:
            self._lock.acquire()
            assert isinstance(self._jobs[job.id], DownloadJobBase)
            self._update_job_status(job, DownloadJobStatus.PAUSED)
            job.cleanup()
        except (AssertionError, KeyError) as excp:
            raise UnknownJobIDException("Unrecognized job") from excp
        finally:
            self._lock.release()

    def start_all_jobs(self):
        """Start (enqueue) all jobs that are idle or paused."""
        try:
            self._lock.acquire()
            for job in self._jobs.values():
                if job.status in [DownloadJobStatus.IDLE, DownloadJobStatus.PAUSED]:
                    self.start_job(job)
        finally:
            self._lock.release()

    def pause_all_jobs(self):
        """Pause all running jobs."""
        try:
            self._lock.acquire()
            for job in self._jobs.values():
                if job.status == DownloadJobStatus.RUNNING:
                    self.pause_job(job)
        finally:
            self._lock.release()

    def cancel_all_jobs(self, preserve_partial: bool = False):
        """Cancel all jobs that are not yet in a terminal state."""
        try:
            self._lock.acquire()
            for job in self._jobs.values():
                if not self._in_terminal_state(job):
                    self.cancel_job(job, preserve_partial)
        finally:
            self._lock.release()

    def _in_terminal_state(self, job: DownloadJobBase) -> bool:
        return job.status in [
            DownloadJobStatus.COMPLETED,
            DownloadJobStatus.ERROR,
            DownloadJobStatus.CANCELLED,
        ]

    def _start_workers(self, max_workers: int):
        """Start the requested number of worker threads."""
        for _i in range(0, max_workers):
            worker = threading.Thread(target=self._download_next_item, daemon=True)
            worker.start()
            self._worker_pool.add(worker)

    def _download_next_item(self):
        """Worker thread gets the next job on the priority queue."""
        done = False
        while not done:
            job = self._queue.get()

            try:  # for debugging priority ordering
                self._lock.acquire()
                job.job_sequence = self._sequence
                self._sequence += 1
            finally:
                self._lock.release()

            if job == STOP_JOB:  # marker that the queue is done
                done = True

            if job.status == DownloadJobStatus.ENQUEUED:  # Don't do anything for non-enqueued jobs (shouldn't happen)
                if not self._quiet:
                    self._logger.info(f"{job.source}: Downloading to {job.destination}")
                do_download = self.select_downloader(job)
                do_download(job)

            if job.status == DownloadJobStatus.CANCELLED:
                self._cleanup_cancelled_job(job)

            self._queue.task_done()

    def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
        """Based on the job type, select the download method."""
        if isinstance(job, DownloadJobURL):
            return self._download_with_resume
        elif isinstance(job, DownloadJobPath):
            return self._download_path
        else:
            raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
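
    # A hedged extension sketch (illustrative only; `DownloadJobFTP` and
    # `_download_ftp` are hypothetical names, not part of this module). A
    # subclass can add a new source type by defining a job class and extending
    # select_downloader() to return the matching download callable:
    #
    #   class DownloadJobFTP(DownloadJobRemoteSource):
    #       source: str = Field(description="ftp:// URL to download")
    #
    #   # ...and in the subclass's select_downloader():
    #   # if isinstance(job, DownloadJobFTP):
    #   #     return self._download_ftp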

    def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
        """Return the URL to fetch for this job."""
        return job.source

    def _download_with_resume(self, job: DownloadJobBase):
        """Do the actual download."""
        dest = None
        try:
            assert isinstance(job, DownloadJobRemoteSource)
            url = self.get_url_for_job(job)
            header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
            open_mode = "wb"
            exist_size = 0

            resp = self._requests.get(url, headers=header, stream=True)
            content_length = int(resp.headers.get("content-length", 0))
            job.total_bytes = content_length

            if job.destination.is_dir():
                try:
                    file_name = ""
                    if match := re.search('filename="(.+)"', resp.headers["Content-Disposition"]):
                        file_name = match.group(1)
                    assert file_name != ""
                    self._validate_filename(
                        job.destination.as_posix(), file_name
                    )  # will raise a ValueError exception if file_name is suspicious
                except ValueError:
                    self._logger.warning(
                        f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
                    )
                    file_name = os.path.basename(url)
                except (KeyError, AssertionError):
                    file_name = os.path.basename(url)
                job.destination = job.destination / file_name
                dest = job.destination
            else:
                dest = job.destination
                dest.parent.mkdir(parents=True, exist_ok=True)

            if dest.exists():
                job.bytes = dest.stat().st_size
                exist_size = job.bytes  # record the partial size so the resume checks below work
                header["Range"] = f"bytes={job.bytes}-"
                open_mode = "ab"
                resp = self._requests.get(url, headers=header, stream=True)  # new request with range

            if exist_size > content_length:
                self._logger.warning("corrupt existing file found. re-downloading")
                os.remove(dest)
                exist_size = 0

            if resp.status_code == 416 or (content_length > 0 and exist_size == content_length):
                self._logger.warning(f"{dest}: complete file found. Skipping.")
                self._update_job_status(job, DownloadJobStatus.COMPLETED)
                return

            if resp.status_code == 206 or exist_size > 0:
                self._logger.warning(f"{dest}: partial file found. Resuming")
            elif resp.status_code != 200:
                raise HTTPError(resp.reason)
            else:
                self._logger.debug(f"{job.source}: Downloading {job.destination}")

            report_delta = job.total_bytes / 100  # report every 1% change
            last_report_bytes = 0

            self._update_job_status(job, DownloadJobStatus.RUNNING)
            with open(dest, open_mode) as file:
                for data in resp.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
                    if job.status != DownloadJobStatus.RUNNING:  # cancelled, paused or errored
                        return
                    job.bytes += file.write(data)
                    if job.bytes - last_report_bytes >= report_delta:
                        last_report_bytes = job.bytes
                        self._update_job_status(job)
            if job.status != DownloadJobStatus.RUNNING:  # cancelled, paused or errored
                return
            self._update_job_status(job, DownloadJobStatus.COMPLETED)
        except KeyboardInterrupt as excp:
            raise excp
        except DuplicateModelException as excp:
            self._logger.error(f"A model with the same hash as {dest} is already installed.")
            job.error = excp
            self._update_job_status(job, DownloadJobStatus.ERROR)
        except Exception as excp:
            self._logger.error(f"An error occurred while downloading/installing {job.source}: {str(excp)}")
            self._logger.error(traceback.format_exc())
            job.error = excp
            self._update_job_status(job, DownloadJobStatus.ERROR)

    def _validate_filename(self, directory: str, filename: str) -> None:
        """Raise ValueError if the filename looks unsafe or is too long for the filesystem."""
        # os.pathconf is POSIX-only; fall back to conservative limits elsewhere.
        pc_name_max = os.pathconf(directory, "PC_NAME_MAX") if hasattr(os, "pathconf") else 260
        pc_path_max = os.pathconf(directory, "PC_PATH_MAX") if hasattr(os, "pathconf") else 260
        if "/" in filename:
            raise ValueError(f"filename '{filename}' contains a path separator")
        if filename.startswith(".."):
            raise ValueError(f"filename '{filename}' is a relative path escape")
        if len(filename) > pc_name_max:
            raise ValueError(f"filename '{filename}' is too long")
        if len(os.path.join(directory, filename)) > pc_path_max:
            raise ValueError(f"destination path for '{filename}' is too long")

    def _update_job_status(self, job: DownloadJobBase, new_status: Optional[DownloadJobStatus] = None):
        """Optionally change the job status and send an event indicating a change of state."""
        if new_status:
            job.status = new_status
        self._logger.debug(f"Status update for download job {job.id}: {job}")
        if new_status == DownloadJobStatus.RUNNING and not job.job_started:
            job.job_started = time.time()
        elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
            job.job_ended = time.time()
        if job.event_handlers:
            for handler in job.event_handlers:
                try:
                    handler(job)
                except KeyboardInterrupt as excp:
                    raise excp
                except Exception as excp:
                    # A handler that raises marks the job as errored.
                    job.error = excp
                    self._update_job_status(job, DownloadJobStatus.ERROR)

    def _download_path(self, job: DownloadJobBase):
        """Called when the source is a Path or path-like object; moves it into place."""
        source = Path(job.source).resolve()
        destination = Path(job.destination).resolve()
        try:
            self._update_job_status(job, DownloadJobStatus.RUNNING)
            if source != destination:
                shutil.move(source, destination)
            self._update_job_status(job, DownloadJobStatus.COMPLETED)
        except OSError as excp:
            job.error = excp
            self._update_job_status(job, DownloadJobStatus.ERROR)

    def _cleanup_cancelled_job(self, job: DownloadJobBase):
        if job.preserve_partial_downloads:
            return
        self._logger.warning(f"Cleaning up leftover files from cancelled download job {job.destination}")
        dest = Path(job.destination)
        if dest.is_file():
            dest.unlink()
        elif dest.is_dir():
            shutil.rmtree(dest.as_posix(), ignore_errors=True)
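

# A minimal smoke test, guarded so importing this module has no side effects.
# The URL below is a placeholder assumption; substitute a real file before running.
if __name__ == "__main__":
    queue = DownloadQueue(max_parallel_dl=2)
    job = queue.create_download_job(
        source="https://example.com/path/to/file.bin",  # hypothetical URL
        destdir=Path("."),
    )
    queue.join()
    print(f"job {job.id} finished with status {job.status}")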