mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
341 lines
13 KiB
Python
341 lines
13 KiB
Python
# Copyright (c) 2023 Lincoln D. Stein and the InvokeAI Development Team
|
|
"""Model download service."""
|
|
|
|
from abc import ABC, abstractmethod
|
|
from enum import Enum
|
|
from functools import total_ordering
|
|
from pathlib import Path
|
|
from typing import Any, Callable, List, Optional, Set, Union
|
|
|
|
from pydantic import BaseModel, Field, PrivateAttr
|
|
from pydantic.networks import AnyHttpUrl
|
|
|
|
from invokeai.backend.model_manager.metadata import RemoteModelFile
|
|
|
|
|
|
class DownloadJobStatus(str, Enum):
    """Lifecycle state of a download job.

    Because the enum mixes in ``str``, every member compares equal to its
    lowercase value string (e.g. ``DownloadJobStatus.ERROR == "error"``).
    """

    # Not enqueued; the job will not run.
    WAITING = "waiting"
    # The download is actively in progress.
    RUNNING = "running"
    # The job finished running.
    COMPLETED = "completed"
    # The user cancelled the job.
    CANCELLED = "cancelled"
    # The job terminated with an error message.
    ERROR = "error"
|
|
|
|
|
|
class DownloadJobCancelledException(Exception):
    """Signals that a download job was cancelled before it could finish."""
|
|
|
|
|
|
class UnknownJobIDException(Exception):
    """This exception is raised when an invalid job id is referenced."""
|
|
|
|
|
|
class ServiceInactiveException(Exception):
    """Raised when a download is requested while the service has not been started."""
|
|
|
|
|
|
# Callback signatures. Job classes are referenced by name (forward references)
# because they are defined later in this module.

# Event callbacks: receive only the job that triggered the event.
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]

# Error callbacks: receive the job plus the exception that occurred (may be None).
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]

# Handlers accepted by the job and service APIs: either the single-file or multi-file form.
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
|
|
|
|
|
class DownloadJobBase(BaseModel):
    """Base of classes to monitor and control downloads.

    Holds the state shared by single-file and multi-file jobs: destination paths,
    status, byte counters, error details, a cancellation flag, and optional event
    callbacks. The cancellation flag and the callbacks are private attributes so that
    they are not serialized and do not appear in the JSON schema.
    """

    # automatically assigned on creation
    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel

    dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
    status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
    bytes: int = Field(default=0, description="Bytes downloaded so far")
    total_bytes: int = Field(default=0, description="Total file size (bytes)")

    # set when an error occurs
    error_type: Optional[str] = Field(default=None, description="Name of exception that caused an error")
    error: Optional[str] = Field(default=None, description="Traceback of the exception that caused an error")

    # internal flag, set by cancel(); note that setting it does not change `status` by itself
    _cancelled: bool = PrivateAttr(default=False)

    # optional event handlers, installed via set_callbacks()
    _on_start: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_progress: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_complete: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
    _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)

    def cancel(self) -> None:
        """Request cancellation of the job.

        Only sets the internal cancellation flag; whatever runs the job is expected
        to observe `cancelled` and update `status` accordingly.
        """
        self._cancelled = True

    # cancelled and the callbacks are private attributes in order to prevent
    # them from being serialized and/or used in the Json Schema
    @property
    def cancelled(self) -> bool:
        """Return true if cancel() has been called on this job."""
        return self._cancelled

    @property
    def complete(self) -> bool:
        """Return true if job completed without errors."""
        return self.status == DownloadJobStatus.COMPLETED

    @property
    def waiting(self) -> bool:
        """Return true if the job is waiting to run."""
        return self.status == DownloadJobStatus.WAITING

    @property
    def running(self) -> bool:
        """Return true if the job is running."""
        return self.status == DownloadJobStatus.RUNNING

    @property
    def errored(self) -> bool:
        """Return true if the job is errored."""
        return self.status == DownloadJobStatus.ERROR

    @property
    def in_terminal_state(self) -> bool:
        """Return true if job has finished, one way or another (completed, cancelled or errored)."""
        return self.status not in (DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING)

    @property
    def on_start(self) -> Optional[DownloadEventHandler]:
        """Return the on_start event handler."""
        return self._on_start

    @property
    def on_progress(self) -> Optional[DownloadEventHandler]:
        """Return the on_progress event handler."""
        return self._on_progress

    @property
    def on_complete(self) -> Optional[DownloadEventHandler]:
        """Return the on_complete event handler."""
        return self._on_complete

    @property
    def on_error(self) -> Optional[DownloadExceptionHandler]:
        """Return the on_error event handler."""
        return self._on_error

    @property
    def on_cancelled(self) -> Optional[DownloadEventHandler]:
        """Return the on_cancelled event handler."""
        return self._on_cancelled

    def set_callbacks(
        self,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> None:
        """Set the callbacks for download events.

        All five handlers are replaced on every call: any handler not passed
        (i.e. left as None) clears the previously installed one.
        """
        self._on_start = on_start
        self._on_progress = on_progress
        self._on_complete = on_complete
        self._on_error = on_error
        self._on_cancelled = on_cancelled
|
|
|
|
|
|
@total_ordering
class DownloadJob(DownloadJobBase):
    """Class to monitor and control a model download request.

    Jobs are ordered by ``priority`` alone (lower value = higher priority), which
    lets them be placed directly into a priority queue / heap.
    """

    # required variables to be passed in on creation
    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")

    # set internally during download process
    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
    job_ended: Optional[str] = Field(
        default=None, description="Timestamp for when the download job ended (completed or errored)"
    )
    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")

    def __hash__(self) -> int:
        """Return hash of the string representation of this object, for indexing.

        NOTE(review): str(self) presumably includes mutable field values such as
        ``status`` and ``bytes``; if so, the hash changes while the download
        progresses, and a job mutated after being added to a set or used as a
        dict key may no longer be found there. Confirm before relying on lookups.
        """
        return hash(str(self))

    # All four comparisons are defined explicitly on ``priority``. Equality is
    # inherited from the model base class and is not priority-based, so letting
    # @total_ordering derive the missing operators from ``__le__`` + ``__eq__``
    # produced inconsistent results for distinct jobs of equal priority
    # (e.g. both ``a < b`` and ``b < a`` were True, and ``a <= b`` did not imply
    # ``b >= a``). A heap compares with ``<`` only, so that made every newly
    # pushed equal-priority job jump ahead of the ones already queued.
    def __lt__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is strictly less than another's."""
        return self.priority < other.priority

    def __le__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is less than or equal to another's."""
        return self.priority <= other.priority

    def __gt__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is strictly greater than another's."""
        return self.priority > other.priority

    def __ge__(self, other: "DownloadJob") -> bool:
        """Return True if this job's priority value is greater than or equal to another's."""
        return self.priority >= other.priority
|
|
|
|
|
|
class MultiFileDownloadJob(DownloadJobBase):
    """Track and control a download composed of several single-file parts.

    Each part is an ordinary ``DownloadJob``; the parts are held in ``download_parts``.
    """

    # The per-file jobs. default_factory gives every instance its own empty set.
    # (The schema description says "List", but the field is a set.)
    download_parts: Set[DownloadJob] = Field(description="List of download parts.", default_factory=set)
|
|
|
|
|
|
class DownloadQueueServiceBase(ABC):
    """Multithreaded queue for downloading models via URL."""

    @abstractmethod
    def start(self, *args: Any, **kwargs: Any) -> None:
        """Start the download worker threads."""

    @abstractmethod
    def stop(self, *args: Any, **kwargs: Any) -> None:
        """Stop the download worker threads."""

    @abstractmethod
    def download(
        self,
        source: AnyHttpUrl,
        dest: Path,
        priority: int = 10,
        access_token: Optional[str] = None,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> DownloadJob:
        """
        Create and enqueue download job.

        :param source: Source of the download as a URL.
        :param dest: Path to download to. See below.
        :param priority: Queue priority; lower values are higher priority.
        :param access_token: Optional authorization token for protected resources.
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the
           indicated events.
        :returns: A DownloadJob object for monitoring the state of the download.

        The `dest` argument is a Path object. Its behavior is:

        1. If the path exists and is a directory, then the URL contents will be downloaded
           into that directory using the filename indicated in the response's `Content-Disposition` field.
           If no content-disposition is present, then the last component of the URL will be used (similar to
           wget's behavior).
        2. If the path does not exist, then it is taken as the name of a new file to create with the downloaded
           content.
        3. If the path exists and is an existing file, then the downloader will try to resume the download from
           the end of the existing file.

        """
        pass

    @abstractmethod
    def multifile_download(
        self,
        parts: List[RemoteModelFile],
        dest: Path,
        access_token: Optional[str] = None,
        submit_job: bool = True,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> MultiFileDownloadJob:
        """
        Create and enqueue a multifile download job.

        :param parts: List of RemoteModelFile descriptors, one per file to download
           (presumably each pairs a URL with a filename/path — confirm against RemoteModelFile).
        :param dest: Path to download to. See below.
        :param access_token: Access token to download the indicated files. If not provided,
           each file's URL may be matched to an access token using the config file matching
           system.
        :param submit_job: If true [default] then submit the job for execution. Otherwise,
           you will need to pass the job to submit_multifile_download().
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the
           indicated events.
        :returns: A MultiFileDownloadJob object for monitoring the state of the download.

        The `dest` argument is a Path object pointing to a directory. All downloads
        will be placed inside this directory. The callbacks will receive the
        MultiFileDownloadJob.
        """
        pass

    @abstractmethod
    def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
        """
        Enqueue a previously-created multi-file download job.

        :param job: A MultiFileDownloadJob created with multifile_download(),
           typically with submit_job=False.
        """
        pass

    @abstractmethod
    def submit_download_job(
        self,
        job: DownloadJob,
        on_start: Optional[DownloadEventHandler] = None,
        on_progress: Optional[DownloadEventHandler] = None,
        on_complete: Optional[DownloadEventHandler] = None,
        on_cancelled: Optional[DownloadEventHandler] = None,
        on_error: Optional[DownloadExceptionHandler] = None,
    ) -> None:
        """
        Enqueue a download job.

        :param job: The DownloadJob
        :param on_start, on_progress, on_complete, on_cancelled, on_error: Callbacks for the
           indicated events.
        """
        pass

    @abstractmethod
    def list_jobs(self) -> List[DownloadJob]:
        """
        List active download jobs.

        :returns List[DownloadJob]: List of download jobs whose state is not "completed."
        """
        pass

    @abstractmethod
    def id_to_job(self, id: int) -> DownloadJob:
        """
        Return the DownloadJob corresponding to the integer ID.

        :param id: ID of the DownloadJob.

        Exceptions:
        * UnknownJobIDException
        """
        pass

    @abstractmethod
    def cancel_all_jobs(self) -> None:
        """Cancel all active and enqueued jobs."""
        pass

    @abstractmethod
    def prune_jobs(self) -> None:
        """Prune completed and errored queue items from the job list."""
        pass

    @abstractmethod
    def cancel_job(self, job: DownloadJobBase) -> None:
        """Cancel the job, clearing partial downloads and putting it into ERROR state.

        NOTE(review): DownloadJobStatus also defines a separate CANCELLED value;
        confirm which terminal state implementations actually assign.
        """
        pass

    @abstractmethod
    def join(self) -> None:
        """Wait until all jobs are off the queue."""
        pass

    @abstractmethod
    def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
        """Wait until the indicated download job has reached a terminal state.

        This will block until the indicated download job has completed,
        been cancelled, or errored out.

        :param job: The job to wait on.
        :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
           the job hasn't completed within the indicated time. (Presumably the default
           of 0 means "wait indefinitely" — confirm in the implementation.)
        :returns: The job that was waited on.
        """
        pass
|