Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00

Commit 0bf14c2830 (parent: b48d4a049d)

add multifile_download() method to download service
invokeai/app/services/download/__init__.py:

```diff
@@ -1,10 +1,17 @@
 """Init file for download queue."""

-from .download_base import DownloadJob, DownloadJobStatus, DownloadQueueServiceBase, UnknownJobIDException
+from .download_base import (
+    DownloadJob,
+    DownloadJobStatus,
+    DownloadQueueServiceBase,
+    MultiFileDownloadJob,
+    UnknownJobIDException,
+)
 from .download_default import DownloadQueueService, TqdmProgress

 __all__ = [
     "DownloadJob",
+    "MultiFileDownloadJob",
     "DownloadQueueServiceBase",
     "DownloadQueueService",
     "TqdmProgress",
```
invokeai/app/services/download/download_base.py:

```diff
@@ -5,11 +5,13 @@ from abc import ABC, abstractmethod
 from enum import Enum
 from functools import total_ordering
 from pathlib import Path
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, List, Optional, Set

 from pydantic import BaseModel, Field, PrivateAttr
 from pydantic.networks import AnyHttpUrl

+from invokeai.backend.model_manager.metadata import RemoteModelFile
+

 class DownloadJobStatus(str, Enum):
     """State of a download job."""
@@ -33,30 +35,19 @@ class ServiceInactiveException(Exception):
     """This exception is raised when user attempts to initiate a download before the service is started."""


-DownloadEventHandler = Callable[["DownloadJob"], None]
-DownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
+DownloadEventHandler = Callable[["DownloadJobBase"], None]
+DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None]
+
+MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
+MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]


-@total_ordering
-class DownloadJob(BaseModel):
-    """Class to monitor and control a model download request."""
+class DownloadJobBase(BaseModel):
+    """Base of classes to monitor and control downloads."""

-    # required variables to be passed in on creation
-    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
-    dest: Path = Field(description="Destination of downloaded model on local disk; a directory or file path")
-    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
-    # automatically assigned on creation
-    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
-    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
-
-    # set internally during download process
+    dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
+    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
-    download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file")
-    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
-    job_ended: Optional[str] = Field(
-        default=None, description="Timestamp for when the download job ended (completed or errored)"
-    )
-    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
     bytes: int = Field(default=0, description="Bytes downloaded so far")
     total_bytes: int = Field(default=0, description="Total file size (bytes)")
@@ -74,14 +65,6 @@ class DownloadJob(BaseModel):
     _on_cancelled: Optional[DownloadEventHandler] = PrivateAttr(default=None)
     _on_error: Optional[DownloadExceptionHandler] = PrivateAttr(default=None)

-    def __hash__(self) -> int:
-        """Return hash of the string representation of this object, for indexing."""
-        return hash(str(self))
-
-    def __le__(self, other: "DownloadJob") -> bool:
-        """Return True if this job's priority is less than another's."""
-        return self.priority <= other.priority
-
     def cancel(self) -> None:
         """Call to cancel the job."""
         self._cancelled = True
@@ -98,6 +81,11 @@
         """Return true if job completed without errors."""
         return self.status == DownloadJobStatus.COMPLETED

+    @property
+    def waiting(self) -> bool:
+        """Return true if the job is waiting to run."""
+        return self.status == DownloadJobStatus.WAITING
+
     @property
     def running(self) -> bool:
         """Return true if the job is running."""
@@ -154,6 +142,39 @@ class DownloadJob(BaseModel):
         self._on_cancelled = on_cancelled


+@total_ordering
+class DownloadJob(DownloadJobBase):
+    """Class to monitor and control a model download request."""
+
+    # required variables to be passed in on creation
+    source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
+    access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
+    # automatically assigned on creation
+    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a sentinel
+    priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
+
+    # set internally during download process
+    job_started: Optional[str] = Field(default=None, description="Timestamp for when the download job started")
+    job_ended: Optional[str] = Field(
+        default=None, description="Timestamp for when the download job ended (completed or errored)"
+    )
+    content_type: Optional[str] = Field(default=None, description="Content type of downloaded file")
+
+    def __hash__(self) -> int:
+        """Return hash of the string representation of this object, for indexing."""
+        return hash(str(self))
+
+    def __le__(self, other: "DownloadJob") -> bool:
+        """Return True if this job's priority is less than another's."""
+        return self.priority <= other.priority
+
+
+class MultiFileDownloadJob(DownloadJobBase):
+    """Class to monitor and control multifile downloads."""
+
+    download_parts: Set[DownloadJob] = Field(default_factory=set, description="List of download parts.")
+
+
 class DownloadQueueServiceBase(ABC):
     """Multithreaded queue for downloading models via URL."""

```
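The net effect of these hunks is a three-level hierarchy: `DownloadJobBase` carries the state shared by every download (destination, final path, status, byte counters, callbacks), `DownloadJob` keeps the single-URL specifics (source, id, priority, timestamps, content type), and `MultiFileDownloadJob` adds only a set of per-file part jobs. A small consumer sketch under that reading, using the re-exports from `__init__.py` above:

```python
from invokeai.app.services.download import DownloadJob, MultiFileDownloadJob

def describe(job: DownloadJob | MultiFileDownloadJob) -> str:
    """One-line summary; the byte counters and dest live on the shared base class."""
    if isinstance(job, MultiFileDownloadJob):
        return f"multifile -> {job.dest}: {len(job.download_parts)} parts, {job.bytes}/{job.total_bytes} bytes"
    return f"{job.source} -> {job.dest}: {job.bytes}/{job.total_bytes} bytes"
```

The abstract queue interface gains a matching entry point in the next hunk: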
```diff
@@ -201,6 +222,33 @@ class DownloadQueueServiceBase(ABC):
         """
         pass

+    @abstractmethod
+    def multifile_download(
+        self,
+        parts: Set[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        on_start: Optional[DownloadEventHandler] = None,
+        on_progress: Optional[DownloadEventHandler] = None,
+        on_complete: Optional[DownloadEventHandler] = None,
+        on_cancelled: Optional[DownloadEventHandler] = None,
+        on_error: Optional[DownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        """
+        Create and enqueue a multifile download job.
+
+        :param parts: Set of URL / filename pairs
+        :param dest: Path to download to. See below.
+        :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
+          events.
+        :returns: A MultiFileDownloadJob object for monitoring the state of the download.
+
+        The `dest` argument is a Path object pointing to a directory. All downloads
+        will be placed inside this directory. The callbacks will receive the
+        MultiFileDownloadJob.
+        """
+        pass
+
     @abstractmethod
     def submit_download_job(
         self,
```
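To make the new surface concrete, here is a minimal usage sketch against the abstract API above. The queue construction, URLs, and file paths are illustrative assumptions rather than part of the commit (note that the signature annotates `parts` as a `Set`, while the new tests pass the plain list returned by `download_urls()`):

```python
from pathlib import Path

from invokeai.app.services.download import DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService()  # assumed default construction; the tests also pass a requests session
queue.start()

# Hypothetical parts: any URL / relative-path pairs will do.
parts = [
    RemoteModelFile(url="https://example.com/repo/config.json", path=Path("repo/config.json")),
    RemoteModelFile(url="https://example.com/repo/weights.safetensors", path=Path("repo/weights.safetensors")),
]

job: MultiFileDownloadJob = queue.multifile_download(
    parts=parts,
    dest=Path("/tmp/models"),  # every part lands under this directory
    on_complete=lambda j: print(f"done: {j.bytes}/{j.total_bytes} bytes"),
)
queue.wait_for_job(job)  # accepts either job type after this commit
queue.stop()
```

The last `download_base.py` hunk widens `wait_for_job` accordingly: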
```diff
@@ -262,7 +310,7 @@ class DownloadQueueServiceBase(ABC):
         pass

     @abstractmethod
-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
         """Wait until the indicated download job has reached a terminal state.

         This will block until the indicated install job has completed,
```
invokeai/app/services/download/download_default.py:

```diff
@@ -18,6 +18,7 @@ from tqdm import tqdm
 from invokeai.app.services.config import InvokeAIAppConfig, get_config
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
+from invokeai.backend.model_manager.metadata import RemoteModelFile
 from invokeai.backend.util.logging import InvokeAILogger

 from .download_base import (
@@ -27,6 +28,9 @@ from .download_base import (
     DownloadJobCancelledException,
     DownloadJobStatus,
     DownloadQueueServiceBase,
+    MultiFileDownloadEventHandler,
+    MultiFileDownloadExceptionHandler,
+    MultiFileDownloadJob,
     ServiceInactiveException,
     UnknownJobIDException,
 )
@@ -54,10 +58,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
         """
         self._app_config = app_config or get_config()
         self._jobs: Dict[int, DownloadJob] = {}
+        self._download_part2parent: Dict[AnyHttpUrl, MultiFileDownloadJob] = {}
         self._next_job_id = 0
         self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
         self._stop_event = threading.Event()
-        self._job_completed_event = threading.Event()
+        self._job_terminated_event = threading.Event()
         self._worker_pool: Set[threading.Thread] = set()
         self._lock = threading.Lock()
         self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@@ -155,6 +160,49 @@ class DownloadQueueService(DownloadQueueServiceBase):
         )
         return job

+    def multifile_download(
+        self,
+        parts: Set[RemoteModelFile],
+        dest: Path,
+        access_token: Optional[str] = None,
+        on_start: Optional[MultiFileDownloadEventHandler] = None,
+        on_progress: Optional[MultiFileDownloadEventHandler] = None,
+        on_complete: Optional[MultiFileDownloadEventHandler] = None,
+        on_cancelled: Optional[MultiFileDownloadEventHandler] = None,
+        on_error: Optional[MultiFileDownloadExceptionHandler] = None,
+    ) -> MultiFileDownloadJob:
+        mfdj = MultiFileDownloadJob(dest=dest)
+        mfdj.set_callbacks(
+            on_start=on_start,
+            on_progress=on_progress,
+            on_complete=on_complete,
+            on_cancelled=on_cancelled,
+            on_error=on_error,
+        )
+
+        for part in parts:
+            url = part.url
+            path = dest / part.path
+            assert path.is_relative_to(dest), "only relative download paths accepted"
+            job = DownloadJob(
+                source=url,
+                dest=path,
+                access_token=access_token,
+            )
+            mfdj.download_parts.add(job)
+            self._download_part2parent[job.source] = mfdj
+
+        for download_job in mfdj.download_parts:
+            self.submit_download_job(
+                download_job,
+                on_start=self._mfd_started,
+                on_progress=self._mfd_progress,
+                on_complete=self._mfd_complete,
+                on_cancelled=self._mfd_cancelled,
+                on_error=self._mfd_error,
+            )
+        return mfdj
+
     def join(self) -> None:
         """Wait for all jobs to complete."""
         self._queue.join()
```
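Note that the implementation adds no new queue machinery: each part becomes an ordinary `DownloadJob`, and the `_download_part2parent` map lets the per-part worker callbacks recover the owning aggregate job. The reverse-lookup pattern in isolation, as a self-contained sketch with simplified stand-in types (not the service's real classes):

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Set

@dataclass(frozen=True)
class Part:
    source: str  # unique download URL for this part

@dataclass
class AggregateJob:
    parts: Set[Part] = field(default_factory=set)
    completed: Set[str] = field(default_factory=set)
    on_complete: Optional[Callable[["AggregateJob"], None]] = None

# reverse map: a part's source URL -> the aggregate job that owns it
part2parent: Dict[str, AggregateJob] = {}

def submit(parent: AggregateJob, urls: List[str]) -> None:
    for url in urls:
        part = Part(source=url)
        parent.parts.add(part)
        part2parent[part.source] = parent  # per-part callbacks use this to find the parent

def part_complete(part: Part) -> None:
    parent = part2parent[part.source]
    parent.completed.add(part.source)
    if parent.completed == {p.source for p in parent.parts} and parent.on_complete:
        parent.on_complete(parent)

agg = AggregateJob(on_complete=lambda j: print(f"all {len(j.parts)} parts done"))
submit(agg, ["https://a/1", "https://a/2"])
part_complete(Part("https://a/1"))
part_complete(Part("https://a/2"))  # prints "all 2 parts done"
```

The next hunks adapt cancellation, waiting, and the worker loop to the two job types: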
```diff
@@ -187,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
         If it is running it will be stopped.
         job.status will be set to DownloadJobStatus.CANCELLED
         """
-        with self._lock:
+        if job.status in [DownloadJobStatus.WAITING, DownloadJobStatus.RUNNING]:
             job.cancel()

     def cancel_all_jobs(self) -> None:
@@ -196,12 +244,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
             if not job.in_terminal_state:
                 self.cancel_job(job)

-    def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
+    def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
         """Block until the indicated job has reached terminal state, or when timeout limit reached."""
         start = time.time()
         while not job.in_terminal_state:
-            if self._job_completed_event.wait(timeout=0.25):  # in case we miss an event
-                self._job_completed_event.clear()
+            if self._job_terminated_event.wait(timeout=0.25):  # in case we miss an event
+                self._job_terminated_event.clear()
             if timeout > 0 and time.time() - start > timeout:
                 raise TimeoutError("Timeout exceeded")
         return job
@@ -230,22 +278,25 @@ class DownloadQueueService(DownloadQueueServiceBase):
                 job.job_started = get_iso_timestamp()
                 self._do_download(job)
                 self._signal_job_complete(job)
-            except (OSError, HTTPError) as excp:
-                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
-                job.error = traceback.format_exc()
-                self._signal_job_error(job, excp)
             except DownloadJobCancelledException:
                 self._signal_job_cancelled(job)
                 self._cleanup_cancelled_job(job)
+            except Exception as excp:
+                job.error_type = excp.__class__.__name__ + f"({str(excp)})"
+                job.error = traceback.format_exc()
+                self._signal_job_error(job, excp)
             finally:
                 job.job_ended = get_iso_timestamp()
-                self._job_completed_event.set()  # signal a change to terminal state
+                self._job_terminated_event.set()  # signal a change to terminal state
+                self._download_part2parent.pop(job.source, None)  # if this is a subpart of a multipart job, remove it
+                self._job_terminated_event.set()
                 self._queue.task_done()

         self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")

     def _do_download(self, job: DownloadJob) -> None:
         """Do the actual download."""
         url = job.source
         header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
         open_mode = "wb"
```
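The ordering of the worker's exception handlers above matters: Python matches `except` clauses top-down, so `DownloadJobCancelledException` must precede the new catch-all `except Exception`, otherwise every cancellation would be recorded as an error. In miniature (a generic sketch, not the service's code):

```python
class Cancelled(Exception):
    """Stand-in for DownloadJobCancelledException."""

def run(step) -> str:
    try:
        step()
    except Cancelled:
        return "cancelled"    # the specific handler must come first
    except Exception as exc:  # the catch-all records any other failure
        return f"error: {exc}"
    return "ok"

def cancel_step() -> None:
    raise Cancelled()

assert run(cancel_step) == "cancelled"  # not "error: ..."
assert run(lambda: None) == "ok"
```

The remaining hunks in this file replace the copy-pasted callback boilerplate with a single helper, drop a stray debugging print, and add the multifile rollup callbacks: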
```diff
@@ -339,7 +390,6 @@ class DownloadQueueService(DownloadQueueServiceBase):

     def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
         # Pull the token from config if it exists and matches the URL
-        print(self._app_config)
         token = None
         for pair in self._app_config.remote_api_tokens or []:
             if re.search(pair.url_regex, str(source)):
@@ -349,25 +399,13 @@ class DownloadQueueService(DownloadQueueServiceBase):

     def _signal_job_started(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.RUNNING
-        if job.on_start:
-            try:
-                job.on_start(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_start callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_start")
         if self._event_bus:
             assert job.download_path
             self._event_bus.emit_download_started(str(job.source), job.download_path.as_posix())

     def _signal_job_progress(self, job: DownloadJob) -> None:
-        if job.on_progress:
-            try:
-                job.on_progress(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_progress callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_progress")
         if self._event_bus:
             assert job.download_path
             self._event_bus.emit_download_progress(
@@ -379,13 +417,7 @@ class DownloadQueueService(DownloadQueueServiceBase):

     def _signal_job_complete(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.COMPLETED
-        if job.on_complete:
-            try:
-                job.on_complete(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_complete callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_complete")
         if self._event_bus:
             assert job.download_path
             self._event_bus.emit_download_complete(
@@ -396,26 +428,21 @@ class DownloadQueueService(DownloadQueueServiceBase):
         if job.status not in [DownloadJobStatus.RUNNING, DownloadJobStatus.WAITING]:
             return
         job.status = DownloadJobStatus.CANCELLED
-        if job.on_cancelled:
-            try:
-                job.on_cancelled(job)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_cancelled callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_cancelled")
         if self._event_bus:
             self._event_bus.emit_download_cancelled(str(job.source))

+        # if multifile download, then signal the parent
+        if parent_job := self._download_part2parent.get(job.source, None):
+            if not parent_job.in_terminal_state:
+                parent_job.status = DownloadJobStatus.CANCELLED
+                self._execute_cb(parent_job, "on_cancelled")
+
     def _signal_job_error(self, job: DownloadJob, excp: Optional[Exception] = None) -> None:
         job.status = DownloadJobStatus.ERROR
         self._logger.error(f"{str(job.source)}: {traceback.format_exception(excp)}")
-        if job.on_error:
-            try:
-                job.on_error(job, excp)
-            except Exception as e:
-                self._logger.error(
-                    f"An error occurred while processing the on_error callback: {traceback.format_exception(e)}"
-                )
+        self._execute_cb(job, "on_error", excp)
         if self._event_bus:
             assert job.error_type
             assert job.error
@@ -430,6 +457,86 @@ class DownloadQueueService(DownloadQueueServiceBase):
             except OSError as excp:
                 self._logger.warning(excp)

+    ########################################
+    # callbacks used for multifile downloads
+    ########################################
+    def _mfd_started(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"File download started: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.waiting:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.status = DownloadJobStatus.RUNNING
+                self._execute_cb(mf_job, "on_start")
+
+    def _mfd_progress(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            if mf_job.cancelled:
+                for part in mf_job.download_parts:
+                    self.cancel_job(part)
+            elif mf_job.running:
+                mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                mf_job.bytes = sum(x.total_bytes for x in mf_job.download_parts)
+                self._execute_cb(mf_job, "on_progress")
+
+    def _mfd_complete(self, download_job: DownloadJob) -> None:
+        self._logger.info(f"Download complete: {download_job.source}")
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+
+            # are there any more active jobs left in this task?
+            if mf_job.running and all(x.complete for x in mf_job.download_parts):
+                mf_job.status = DownloadJobStatus.COMPLETED
+                self._execute_cb(mf_job, "on_complete")
+
+            # we're done with this sub-job
+            self._job_terminated_event.set()
+
+    def _mfd_cancelled(self, download_job: DownloadJob) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+
+            if not mf_job.in_terminal_state:
+                self._logger.warning(f"Download cancelled: {download_job.source}")
+                mf_job.cancel()
+
+            for s in mf_job.download_parts:
+                self.cancel_job(s)
+
+    def _mfd_error(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
+        with self._lock:
+            mf_job = self._download_part2parent[download_job.source]
+            assert mf_job is not None
+            if not mf_job.in_terminal_state:
+                mf_job.status = download_job.status
+                mf_job.error = download_job.error
+                mf_job.error_type = download_job.error_type
+                self._execute_cb(mf_job, "on_error", excp)
+                self._logger.error(
+                    f"Cancelling {mf_job.dest} due to an error while downloading {download_job.source}: {str(excp)}"
+                )
+                for s in [x for x in mf_job.download_parts if x.running]:
+                    self.cancel_job(s)
+                self._download_part2parent.pop(download_job.source)
+                self._job_terminated_event.set()
+
+    def _execute_cb(
+        self,
+        job: DownloadJob | MultiFileDownloadJob,
+        callback_name: str,
+        excp: Optional[Exception] = None,
+    ) -> None:
+        if callback := getattr(job, callback_name, None):
+            args = [job, excp] if excp else [job]
+            try:
+                callback(*args)
+            except Exception as e:
+                self._logger.error(
+                    f"An error occurred while processing the {callback_name} callback: {traceback.format_exception(e)}"
+                )
+
+
 def get_pc_name_max(directory: str) -> int:
     if hasattr(os, "pathconf"):
```
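The new `_execute_cb` helper collapses six nearly identical try/except blocks into one getattr-based dispatcher that shields the worker thread from misbehaving user callbacks. The same pattern in isolation (generic names; a sketch rather than the service's code):

```python
import logging
import traceback
from typing import Callable, Optional

logger = logging.getLogger(__name__)

class Job:
    """Minimal stand-in for a job object carrying optional callback attributes."""

    def __init__(self, on_start: Optional[Callable] = None, on_error: Optional[Callable] = None):
        self.on_start = on_start
        self.on_error = on_error

def execute_cb(job: Job, callback_name: str, excp: Optional[Exception] = None) -> None:
    # Look the callback up by attribute name; absent or None callbacks are skipped.
    if callback := getattr(job, callback_name, None):
        args = [job, excp] if excp else [job]
        try:
            callback(*args)
        except Exception:
            # A failing user callback is logged, never allowed to propagate.
            logger.error("error in %s callback:\n%s", callback_name, traceback.format_exc())

job = Job(on_start=lambda j: print("started"))
execute_cb(job, "on_start")     # prints "started"
execute_cb(job, "on_progress")  # attribute missing: silently skipped
execute_cb(job, "on_error")     # attribute is None: silently skipped
```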
model install service:

```diff
@@ -689,7 +689,6 @@ class ModelInstallService(ModelInstallServiceBase):
         self._logger.debug(f"metadata={metadata}")
         if metadata and isinstance(metadata, ModelMetadataWithFiles):
             remote_files = metadata.download_urls(session=self._session)
-            print(remote_files)
         else:
             remote_files = [RemoteModelFile(url=source.url, path=Path("."), size=0)]
         return self._import_remote_model(
```
RemoteModelFile (invokeai.backend.model_manager.metadata):

```diff
@@ -37,7 +37,7 @@ class RemoteModelFile(BaseModel):

     url: AnyHttpUrl = Field(description="The url to download this model file")
     path: Path = Field(description="The path to the file, relative to the model root")
-    size: int = Field(description="The size of this file, in bytes")
+    size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
     sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
```
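With `size` now optional and defaulted to 0, a remote file can be described by URL and relative path alone, which is exactly what the new error-path test and the `ModelInstallService` fallback above rely on. For instance (hypothetical URL):

```python
from pathlib import Path

from invokeai.backend.model_manager.metadata import RemoteModelFile

# The file size is unknown ahead of time; it now defaults to 0 instead of being required.
f = RemoteModelFile(url="https://example.com/foo.safetensors", path=Path("foo/foo.safetensors"))
assert f.size == 0
```

Defaulting to 0 rather than None presumably keeps byte-count sums over parts well-defined.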
download queue tests:

```diff
@@ -4,7 +4,7 @@ import re
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Generator
+from typing import Generator, Optional

 import pytest
 from pydantic.networks import AnyHttpUrl
@@ -13,7 +13,8 @@ from requests_testadapter import TestAdapter, TestSession

 from invokeai.app.services.config import get_config
 from invokeai.app.services.config.config_default import URLRegexTokenPair
-from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
+from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
+from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 from tests.test_nodes import TestEventService

@@ -67,11 +68,116 @@ def session() -> Session:
     return sess


+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        print(f"bytes = {job.bytes}")
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
+    queue.join()
+
+    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
+    assert job.bytes > 0, "expected download bytes to be positive"
+    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
+    assert Path(
+        tmp_path, "sdxl-turbo/model_index.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
+    assert Path(
+        tmp_path, "sdxl-turbo/text_encoder/config.json"
+    ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"
+
+    assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=10, method="thread")
+def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+    events = set()
+
+    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
+        events.add(job.status)
+
+    queue = DownloadQueueService(
+        requests_session=mm2_session,
+    )
+    queue.start()
+    files = metadata.download_urls(session=mm2_session)
+    # this will give a 404 error
+    files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
+    job = queue.multifile_download(
+        parts=files,
+        dest=tmp_path,
+        on_start=event_handler,
+        on_progress=event_handler,
+        on_complete=event_handler,
+        on_error=event_handler,
+    )
+    queue.join()
+
+    assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
+    assert "HTTPError(NOT FOUND)" in job.error_type
+    assert DownloadJobStatus.ERROR in events
+    queue.stop()
+
+
+@pytest.mark.timeout(timeout=15, method="thread")
+def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
+    event_bus = TestEventService()
+
+    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
+    queue.start()
+
+    cancelled = False
+
+    def cancelled_callback(job: DownloadJob) -> None:
+        nonlocal cancelled
+        cancelled = True
+
+    def handler(signum, frame):
+        raise TimeoutError("Join took too long to return")
+
+    fetcher = HuggingFaceMetadataFetch(mm2_session)
+    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
+
+    job = queue.multifile_download(
+        parts=metadata.download_urls(session=mm2_session),
+        dest=tmp_path,
+        on_cancelled=cancelled_callback,
+    )
+    queue.cancel_job(job)
+    queue.join()
+
+    assert job.status == DownloadJobStatus.CANCELLED
+    assert cancelled
+    events = event_bus.events
+    assert "download_cancelled" in [x.event_name for x in events]
+    queue.stop()
+
+
 @pytest.mark.timeout(timeout=20, method="thread")
 def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
     events = set()

-    def event_handler(job: DownloadJob) -> None:
+    def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
         events.add(job.status)

     queue = DownloadQueueService(
```