mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up type checking for single file and multifile download job callbacks
This commit is contained in:
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from functools import total_ordering
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Optional, Set
|
||||
from typing import Any, Callable, List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
@ -35,12 +35,12 @@ class ServiceInactiveException(Exception):
|
||||
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
||||
|
||||
|
||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
||||
DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None]
|
||||
|
||||
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
|
||||
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
|
||||
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
|
||||
|
||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Base of classes to monitor and control downloads."""
|
||||
@ -228,6 +228,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
parts: Set[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
@ -239,6 +240,11 @@ class DownloadQueueServiceBase(ABC):
|
||||
|
||||
:param parts: Set of URL / filename pairs
|
||||
:param dest: Path to download to. See below.
|
||||
:param access_token: Access token to download the indicated files. If not provided,
|
||||
each file's URL may be matched to an access token using the config file matching
|
||||
system.
|
||||
:param submit_job: If true [default] then submit the job for execution. Otherwise,
|
||||
you will need to pass the job to submit_multifile_download().
|
||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||
events.
|
||||
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
|
||||
@ -249,6 +255,15 @@ class DownloadQueueServiceBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
"""
|
||||
Enqueue a previously-created multi-file download job.
|
||||
|
||||
:param job: A MultiFileDownloadJob created with multifile_download()
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def submit_download_job(
|
||||
self,
|
||||
|
@ -25,11 +25,10 @@ from .download_base import (
|
||||
DownloadEventHandler,
|
||||
DownloadExceptionHandler,
|
||||
DownloadJob,
|
||||
DownloadJobBase,
|
||||
DownloadJobCancelledException,
|
||||
DownloadJobStatus,
|
||||
DownloadQueueServiceBase,
|
||||
MultiFileDownloadEventHandler,
|
||||
MultiFileDownloadExceptionHandler,
|
||||
MultiFileDownloadJob,
|
||||
ServiceInactiveException,
|
||||
UnknownJobIDException,
|
||||
@ -165,11 +164,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
parts: Set[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
on_start: Optional[MultiFileDownloadEventHandler] = None,
|
||||
on_progress: Optional[MultiFileDownloadEventHandler] = None,
|
||||
on_complete: Optional[MultiFileDownloadEventHandler] = None,
|
||||
on_cancelled: Optional[MultiFileDownloadEventHandler] = None,
|
||||
on_error: Optional[MultiFileDownloadExceptionHandler] = None,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
mfdj = MultiFileDownloadJob(dest=dest)
|
||||
mfdj.set_callbacks(
|
||||
@ -191,8 +190,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
|
||||
for download_job in mfdj.download_parts:
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
for download_job in job.download_parts:
|
||||
self.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._mfd_started,
|
||||
@ -201,7 +203,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
on_cancelled=self._mfd_cancelled,
|
||||
on_error=self._mfd_error,
|
||||
)
|
||||
return mfdj
|
||||
|
||||
def join(self) -> None:
|
||||
"""Wait for all jobs to complete."""
|
||||
|
Reference in New Issue
Block a user