mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fbede84405
* add base definition of download manager * basic functionality working * add unit tests for download queue * add documentation and FastAPI route * fix docs * add missing test dependency; fix import ordering * fix file path length checking on windows * fix ruff check error * move release() into the __del__ method * disable testing of stderr messages due to issues with pytest capsys fixture * fix unsorted imports * harmonized implementation of start() and stop() calls in download and & install modules * Update invokeai/app/services/download/download_base.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * replace test datadir fixture with tmp_path * replace DownloadJobBase->DownloadJob in download manager documentation * make source and dest arguments to download_queue.download() an AnyHttpURL and Path respectively * fix pydantic typecheck errors in the download unit test * ruff formatting * add "job cancelled" as an event rather than an exception * fix ruff errors * Update invokeai/app/services/download/download_default.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * use threading.Event to stop service worker threads; handle unfinished job edge cases * remove dangling STOP job definition * fix ruff complaint * fix ruff check again * avoid race condition when start() and stop() are called simultaneously from different threads * avoid race condition in stop() when a job becomes active while shutting down --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
218 lines
7.9 KiB
Python
218 lines
7.9 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""Model download service."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from functools import total_ordering
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional
|
|
|
|
from pydantic import BaseModel, Field, PrivateAttr
|
|
from pydantic.networks import AnyHttpUrl
|
|
|
|
|
|
class DownloadJobStatus(str, Enum):
    """State of a download job.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value (e.g. ``DownloadJobStatus.RUNNING == "running"``).
    """

    WAITING = "waiting"  # not enqueued, will not run
    RUNNING = "running"  # actively downloading
    COMPLETED = "completed"  # finished running
    CANCELLED = "cancelled"  # user cancelled
    ERROR = "error"  # terminated with an error message
|
|
|
|
|
|
class DownloadJobCancelledException(Exception):
    """This exception is raised when a download job is cancelled."""
|
|
|
|
|
|
class UnknownJobIDException(Exception):
    """This exception is raised when an invalid job id is referenced."""
|
|
|
|
|
|
class ServiceInactiveException(Exception):
    """This exception is raised when a user attempts to initiate a download before the service is started."""
|
|
|
|
|
|
# Signature shared by all download event callbacks (on_start, on_progress,
# on_complete, on_cancelled, on_error): the callback receives the DownloadJob
# and returns nothing. "DownloadJob" is a string forward reference because the
# class is defined further down in this module.
DownloadEventHandler = Callable[["DownloadJob"], None]
|
|
|
|
|
|
@total_ordering
class DownloadJob(BaseModel):
    """Class to monitor and control a model download request.

    Ordering between jobs is defined by ``priority`` only (see ``__le__``);
    ``total_ordering`` derives the remaining rich comparisons from it.
    Equality is not defined here and is inherited from ``BaseModel``.
    """

    # required variables to be passed in on creation
    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
    dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
    # automatically assigned on creation
    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")

    # set internally during download process
    status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
    job_ended: Optional[str] = Field(
        default=None, description="Timestamp for when the download job ended (completed or errored)"
    )
    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total file size (bytes)")

    # set when an error occurs
    error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
    error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")

    # internal flag
    _cancelled: bool = PrivateAttr(default=False)

    # optional event handlers passed in on creation
    _on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_error: Optional[DownloadEventHandler] = PrivateAttr(default=None)

    def __le__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is less than or equal to another's."""
        if not isinstance(other, DownloadJob):
            # Let Python try the reflected operation / raise TypeError instead
            # of failing with an AttributeError on a foreign operand.
            return NotImplemented
        return self.priority <= other.priority

    def cancel(self) -> None:
        """Call to cancel the job.

        This only sets the internal cancelled flag, which is exposed through
        the ``cancelled`` property.
        """
        self._cancelled = True

    # cancelled and the callbacks are private attributes in order to prevent
    # them from being serialized and/or used in the Json Schema
    @property
    def cancelled(self) -> bool:
        """Return True if ``cancel()`` has been called on this job."""
        return self._cancelled

    @property
    def on_start(self) -> Optional[DownloadEventHandler]:
        """Return the on_start event handler."""
        return self._on_start

    @property
    def on_progress(self) -> Optional[DownloadEventHandler]:
        """Return the on_progress event handler."""
        return self._on_progress

    @property
    def on_complete(self) -> Optional[DownloadEventHandler]:
        """Return the on_complete event handler."""
        return self._on_complete

    @property
    def on_error(self) -> Optional[DownloadEventHandler]:
        """Return the on_error event handler."""
        return self._on_error

    @property
    def on_cancelled(self) -> Optional[DownloadEventHandler]:
        """Return the on_cancelled event handler."""
        return self._on_cancelled

    def set_callbacks(
        self,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadEventHandler] = None,
    ) -> None:
        """Set the callbacks for download events.

        All five handlers are overwritten on every call; any handler that is
        not passed is reset to None.

        :param on_start: Handler stored for the start event.
        :param on_progress: Handler stored for the progress event.
        :param on_complete: Handler stored for the completion event.
        :param on_cancelled: Handler stored for the cancellation event.
        :param on_error: Handler stored for the error event.
        """
        self._on_start = on_start
        self._on_progress = on_progress
        self._on_complete = on_complete
        self._on_error = on_error
        self._on_cancelled = on_cancelled
|
|
|
|
|
|
class DownloadQueueServiceBase(ABC):
    """Multithreaded queue for downloading models via URL.

    This is the abstract interface only; every method below must be
    provided by a concrete subclass.
    """

    @abstractmethod
    def start(self, *args: Any, **kwargs: Any) -> None:
        """Start the download worker threads."""

    @abstractmethod
    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Stop the download worker threads."""

    @abstractmethod
    def download(
        self,
        source: AnyHttpUrl,
        dest: Path,
        priority: int = 10,
        access_token: Optional[str] = None,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadEventHandler] = None,
    ) -> DownloadJob:
        """
        Create a download job.

        :param source: Source of the download as a URL.
        :param dest: Path to download to. See below.
        :param priority: Queue priority of the job; lower values are higher priority.
        :param access_token: Optional authorization token for protected resources.
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks
         for the indicated events.
        :returns: A DownloadJob object for monitoring the state of the download.

        The `dest` argument is a Path object. Its behavior is:

        1. If the path exists and is a directory, then the URL contents will be downloaded
           into that directory using the filename indicated in the response's `Content-Disposition` field.
           If no content-disposition is present, then the last component of the URL will be used (similar to
           wget's behavior).
        2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
           content.
        3. If the path exists and is an existing file, then the downloader will try to resume the download from
           the end of the existing file.

        """

    @abstractmethod
    def list_jobs(self) -> List[DownloadJob]:
        """
        List active download jobs.

        :returns: List of download jobs whose state is not "completed."
        """

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJob:
        """
        Return the DownloadJob corresponding to the integer ID.

        :param id: ID of the DownloadJob.

        Exceptions:
        * UnknownJobIDException
        """

    @abstractmethod
    def cancel_all_jobs(self) -> None:
        """Cancel all active and enqueued jobs."""

    @abstractmethod
    def prune_jobs(self) -> None:
        """Prune completed and errored queue items from the job list."""

    @abstractmethod
    def cancel_job(self, job: DownloadJob) -> None:
        """
        Cancel the job, clearing partial downloads.

        NOTE(review): this docstring previously stated that the job is put into
        the ERROR state, but DownloadJobStatus defines a dedicated CANCELLED
        status -- confirm the resulting state against the concrete implementation.

        :param job: The DownloadJob to cancel.
        """

    @abstractmethod
    def join(self) -> None:
        """Wait until all jobs are off the queue."""
|