clean up type checking for single file and multifile download job callbacks

This commit is contained in:
Lincoln Stein
2024-05-13 18:31:40 -04:00
parent 0bf14c2830
commit 287c679f7b
3 changed files with 186 additions and 172 deletions

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from enum import Enum
from functools import total_ordering
from pathlib import Path
from typing import Any, Callable, List, Optional, Set
from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl
@ -35,12 +35,12 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJobBase"], None]
DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None]
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
@ -228,6 +228,7 @@ class DownloadQueueServiceBase(ABC):
parts: Set[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
@ -239,6 +240,11 @@ class DownloadQueueServiceBase(ABC):
:param parts: Set of URL / filename pairs
:param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
@ -249,6 +255,15 @@ class DownloadQueueServiceBase(ABC):
"""
pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
@abstractmethod
def submit_download_job(
self,

View File

@ -25,11 +25,10 @@ from .download_base import (
DownloadEventHandler,
DownloadExceptionHandler,
DownloadJob,
DownloadJobBase,
DownloadJobCancelledException,
DownloadJobStatus,
DownloadQueueServiceBase,
MultiFileDownloadEventHandler,
MultiFileDownloadExceptionHandler,
MultiFileDownloadJob,
ServiceInactiveException,
UnknownJobIDException,
@ -165,11 +164,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
parts: Set[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
on_start: Optional[MultiFileDownloadEventHandler] = None,
on_progress: Optional[MultiFileDownloadEventHandler] = None,
on_complete: Optional[MultiFileDownloadEventHandler] = None,
on_cancelled: Optional[MultiFileDownloadEventHandler] = None,
on_error: Optional[MultiFileDownloadExceptionHandler] = None,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest)
mfdj.set_callbacks(
@ -191,8 +190,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
self.submit_multifile_download(mfdj)
return mfdj
for download_job in mfdj.download_parts:
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job(
download_job,
on_start=self._mfd_started,
@ -201,7 +203,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error,
)
return mfdj
def join(self) -> None:
"""Wait for all jobs to complete."""