clean up type checking for single file and multifile download job callbacks

This commit is contained in:
Lincoln Stein 2024-05-13 18:31:40 -04:00
parent 0bf14c2830
commit 287c679f7b
3 changed files with 186 additions and 172 deletions

View File

@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from functools import total_ordering from functools import total_ordering
from pathlib import Path from pathlib import Path
from typing import Any, Callable, List, Optional, Set from typing import Any, Callable, List, Optional, Set, Union
from pydantic import BaseModel, Field, PrivateAttr from pydantic import BaseModel, Field, PrivateAttr
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
@ -35,12 +35,12 @@ class ServiceInactiveException(Exception):
"""This exception is raised when user attempts to initiate a download before the service is started.""" """This exception is raised when user attempts to initiate a download before the service is started."""
DownloadEventHandler = Callable[["DownloadJobBase"], None] SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None] SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None] MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None] MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
class DownloadJobBase(BaseModel): class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads.""" """Base of classes to monitor and control downloads."""
@ -228,6 +228,7 @@ class DownloadQueueServiceBase(ABC):
parts: Set[RemoteModelFile], parts: Set[RemoteModelFile],
dest: Path, dest: Path,
access_token: Optional[str] = None, access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None, on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None,
@ -239,6 +240,11 @@ class DownloadQueueServiceBase(ABC):
:param parts: Set of URL / filename pairs :param parts: Set of URL / filename pairs
:param dest: Path to download to. See below. :param dest: Path to download to. See below.
:param access_token: Access token to download the indicated files. If not provided,
each file's URL may be matched to an access token using the config file matching
system.
:param submit_job: If true [default] then submit the job for execution. Otherwise,
you will need to pass the job to submit_multifile_download().
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated :param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
events. events.
:returns: A MultiFileDownloadJob object for monitoring the state of the download. :returns: A MultiFileDownloadJob object for monitoring the state of the download.
@ -249,6 +255,15 @@ class DownloadQueueServiceBase(ABC):
""" """
pass pass
@abstractmethod
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
"""
Enqueue a previously-created multi-file download job.
:param job: A MultiFileDownloadJob created with multifile_download()
"""
pass
@abstractmethod @abstractmethod
def submit_download_job( def submit_download_job(
self, self,

View File

@ -25,11 +25,10 @@ from .download_base import (
DownloadEventHandler, DownloadEventHandler,
DownloadExceptionHandler, DownloadExceptionHandler,
DownloadJob, DownloadJob,
DownloadJobBase,
DownloadJobCancelledException, DownloadJobCancelledException,
DownloadJobStatus, DownloadJobStatus,
DownloadQueueServiceBase, DownloadQueueServiceBase,
MultiFileDownloadEventHandler,
MultiFileDownloadExceptionHandler,
MultiFileDownloadJob, MultiFileDownloadJob,
ServiceInactiveException, ServiceInactiveException,
UnknownJobIDException, UnknownJobIDException,
@ -165,11 +164,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
parts: Set[RemoteModelFile], parts: Set[RemoteModelFile],
dest: Path, dest: Path,
access_token: Optional[str] = None, access_token: Optional[str] = None,
on_start: Optional[MultiFileDownloadEventHandler] = None, on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[MultiFileDownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[MultiFileDownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[MultiFileDownloadEventHandler] = None, on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[MultiFileDownloadExceptionHandler] = None, on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob: ) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest) mfdj = MultiFileDownloadJob(dest=dest)
mfdj.set_callbacks( mfdj.set_callbacks(
@ -191,8 +190,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
) )
mfdj.download_parts.add(job) mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj self._download_part2parent[job.source] = mfdj
self.submit_multifile_download(mfdj)
return mfdj
for download_job in mfdj.download_parts: def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
for download_job in job.download_parts:
self.submit_download_job( self.submit_download_job(
download_job, download_job,
on_start=self._mfd_started, on_start=self._mfd_started,
@ -201,7 +203,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
on_cancelled=self._mfd_cancelled, on_cancelled=self._mfd_cancelled,
on_error=self._mfd_error, on_error=self._mfd_error,
) )
return mfdj
def join(self) -> None: def join(self) -> None:
"""Wait for all jobs to complete.""" """Wait for all jobs to complete."""

View File

@ -68,6 +68,161 @@ def session() -> Session:
return sess return sess
@pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
assert isinstance(job.id, int), "expected the job id to be numeric"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
queue.join()
jobs = queue.list_jobs()
print(jobs)
assert len(jobs) == 2
jobs_dict = {str(x.source): x for x in jobs}
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
queue.join()
events = event_bus.events
assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
# test a failure
event_bus.events = [] # reset our accumulator
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
queue.join()
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
callback_ran = False
def broken_callback(job: DownloadJob) -> None:
nonlocal callback_ran
callback_ran = True
print(1 / 0) # deliberate error here
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_progress=broken_callback,
)
queue.join()
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
assert Path(tmp_path, "mock12345.safetensors").exists()
assert callback_ran
# LS: The pytest capsys fixture does not seem to be working. I can see the
# correct stderr message in the pytest log, but it is not appearing in
# capsys.readouterr().
# captured = capsys.readouterr()
# assert re.search("division by zero", captured.err)
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
cancelled = False
def slow_callback(job: DownloadJob) -> None:
time.sleep(2)
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=slow_callback,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session) fetcher = HuggingFaceMetadataFetch(mm2_session)
@ -139,7 +294,7 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=15, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
event_bus = TestEventService() event_bus = TestEventService()
@ -172,163 +327,6 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
assert "download_cancelled" in [x.event_name for x in events] assert "download_cancelled" in [x.event_name for x in events]
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
events = set()
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=event_handler,
on_progress=event_handler,
on_complete=event_handler,
on_error=event_handler,
)
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
assert isinstance(job.id, int), "expected the job id to be numeric"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
queue.join()
jobs = queue.list_jobs()
print(jobs)
assert len(jobs) == 2
jobs_dict = {str(x.source): x for x in jobs}
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
queue.join()
events = event_bus.events
assert len(events) == 3
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
assert events[0].event_name == "download_started"
assert events[1].event_name == "download_progress"
assert events[1].payload["total_bytes"] > 0
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
assert events[2].event_name == "download_complete"
assert events[2].payload["total_bytes"] == 32029
# test a failure
event_bus.events = [] # reset our accumulator
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
queue.join()
events = event_bus.events
print("\n".join([x.model_dump_json() for x in events]))
assert len(events) == 1
assert events[0].event_name == "download_error"
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
assert events[0].payload["error"] is not None
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
queue.stop()
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
)
queue.start()
callback_ran = False
def broken_callback(job: DownloadJob) -> None:
nonlocal callback_ran
callback_ran = True
print(1 / 0) # deliberate error here
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_progress=broken_callback,
)
queue.join()
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
assert Path(tmp_path, "mock12345.safetensors").exists()
assert callback_ran
# LS: The pytest capsys fixture does not seem to be working. I can see the
# correct stderr message in the pytest log, but it is not appearing in
# capsys.readouterr().
# captured = capsys.readouterr()
# assert re.search("division by zero", captured.err)
queue.stop()
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue.start()
cancelled = False
def slow_callback(job: DownloadJob) -> None:
time.sleep(2)
def cancelled_callback(job: DownloadJob) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
on_start=slow_callback,
on_cancelled=cancelled_callback,
)
queue.cancel_job(job)
queue.join()
assert job.status == DownloadJobStatus.CANCELLED
assert cancelled
events = event_bus.events
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@contextmanager @contextmanager
def clear_config() -> Generator[None, None, None]: def clear_config() -> Generator[None, None, None]:
try: try: