mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
clean up type checking for single file and multifile download job callbacks
This commit is contained in:
parent
0bf14c2830
commit
287c679f7b
@ -5,7 +5,7 @@ from abc import ABC, abstractmethod
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import total_ordering
|
from functools import total_ordering
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, List, Optional, Set
|
from typing import Any, Callable, List, Optional, Set, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr
|
from pydantic import BaseModel, Field, PrivateAttr
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
@ -35,12 +35,12 @@ class ServiceInactiveException(Exception):
|
|||||||
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
"""This exception is raised when user attempts to initiate a download before the service is started."""
|
||||||
|
|
||||||
|
|
||||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
SingleFileDownloadEventHandler = Callable[["DownloadJob"], None]
|
||||||
DownloadExceptionHandler = Callable[["DownloadJobBase", Optional[Exception]], None]
|
SingleFileDownloadExceptionHandler = Callable[["DownloadJob", Optional[Exception]], None]
|
||||||
|
|
||||||
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
|
MultiFileDownloadEventHandler = Callable[["MultiFileDownloadJob"], None]
|
||||||
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
|
MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[Exception]], None]
|
||||||
|
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||||
|
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||||
|
|
||||||
class DownloadJobBase(BaseModel):
|
class DownloadJobBase(BaseModel):
|
||||||
"""Base of classes to monitor and control downloads."""
|
"""Base of classes to monitor and control downloads."""
|
||||||
@ -228,6 +228,7 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
parts: Set[RemoteModelFile],
|
parts: Set[RemoteModelFile],
|
||||||
dest: Path,
|
dest: Path,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
|
submit_job: bool = True,
|
||||||
on_start: Optional[DownloadEventHandler] = None,
|
on_start: Optional[DownloadEventHandler] = None,
|
||||||
on_progress: Optional[DownloadEventHandler] = None,
|
on_progress: Optional[DownloadEventHandler] = None,
|
||||||
on_complete: Optional[DownloadEventHandler] = None,
|
on_complete: Optional[DownloadEventHandler] = None,
|
||||||
@ -239,6 +240,11 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
|
|
||||||
:param parts: Set of URL / filename pairs
|
:param parts: Set of URL / filename pairs
|
||||||
:param dest: Path to download to. See below.
|
:param dest: Path to download to. See below.
|
||||||
|
:param access_token: Access token to download the indicated files. If not provided,
|
||||||
|
each file's URL may be matched to an access token using the config file matching
|
||||||
|
system.
|
||||||
|
:param submit_job: If true [default] then submit the job for execution. Otherwise,
|
||||||
|
you will need to pass the job to submit_multifile_download().
|
||||||
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
:param on_start, on_progress, on_complete, on_error: Callbacks for the indicated
|
||||||
events.
|
events.
|
||||||
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
|
:returns: A MultiFileDownloadJob object for monitoring the state of the download.
|
||||||
@ -249,6 +255,15 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||||
|
"""
|
||||||
|
Enqueue a previously-created multi-file download job.
|
||||||
|
|
||||||
|
:param job: A MultiFileDownloadJob created with multifile_download()
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def submit_download_job(
|
def submit_download_job(
|
||||||
self,
|
self,
|
||||||
|
@ -25,11 +25,10 @@ from .download_base import (
|
|||||||
DownloadEventHandler,
|
DownloadEventHandler,
|
||||||
DownloadExceptionHandler,
|
DownloadExceptionHandler,
|
||||||
DownloadJob,
|
DownloadJob,
|
||||||
|
DownloadJobBase,
|
||||||
DownloadJobCancelledException,
|
DownloadJobCancelledException,
|
||||||
DownloadJobStatus,
|
DownloadJobStatus,
|
||||||
DownloadQueueServiceBase,
|
DownloadQueueServiceBase,
|
||||||
MultiFileDownloadEventHandler,
|
|
||||||
MultiFileDownloadExceptionHandler,
|
|
||||||
MultiFileDownloadJob,
|
MultiFileDownloadJob,
|
||||||
ServiceInactiveException,
|
ServiceInactiveException,
|
||||||
UnknownJobIDException,
|
UnknownJobIDException,
|
||||||
@ -165,11 +164,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
parts: Set[RemoteModelFile],
|
parts: Set[RemoteModelFile],
|
||||||
dest: Path,
|
dest: Path,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
on_start: Optional[MultiFileDownloadEventHandler] = None,
|
on_start: Optional[DownloadEventHandler] = None,
|
||||||
on_progress: Optional[MultiFileDownloadEventHandler] = None,
|
on_progress: Optional[DownloadEventHandler] = None,
|
||||||
on_complete: Optional[MultiFileDownloadEventHandler] = None,
|
on_complete: Optional[DownloadEventHandler] = None,
|
||||||
on_cancelled: Optional[MultiFileDownloadEventHandler] = None,
|
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||||
on_error: Optional[MultiFileDownloadExceptionHandler] = None,
|
on_error: Optional[DownloadExceptionHandler] = None,
|
||||||
) -> MultiFileDownloadJob:
|
) -> MultiFileDownloadJob:
|
||||||
mfdj = MultiFileDownloadJob(dest=dest)
|
mfdj = MultiFileDownloadJob(dest=dest)
|
||||||
mfdj.set_callbacks(
|
mfdj.set_callbacks(
|
||||||
@ -191,8 +190,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
)
|
)
|
||||||
mfdj.download_parts.add(job)
|
mfdj.download_parts.add(job)
|
||||||
self._download_part2parent[job.source] = mfdj
|
self._download_part2parent[job.source] = mfdj
|
||||||
|
self.submit_multifile_download(mfdj)
|
||||||
|
return mfdj
|
||||||
|
|
||||||
for download_job in mfdj.download_parts:
|
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||||
|
for download_job in job.download_parts:
|
||||||
self.submit_download_job(
|
self.submit_download_job(
|
||||||
download_job,
|
download_job,
|
||||||
on_start=self._mfd_started,
|
on_start=self._mfd_started,
|
||||||
@ -201,7 +203,6 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
on_cancelled=self._mfd_cancelled,
|
on_cancelled=self._mfd_cancelled,
|
||||||
on_error=self._mfd_error,
|
on_error=self._mfd_error,
|
||||||
)
|
)
|
||||||
return mfdj
|
|
||||||
|
|
||||||
def join(self) -> None:
|
def join(self) -> None:
|
||||||
"""Wait for all jobs to complete."""
|
"""Wait for all jobs to complete."""
|
||||||
|
@ -68,6 +68,161 @@ def session() -> Session:
|
|||||||
return sess
|
return sess
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
|
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||||
|
events = set()
|
||||||
|
|
||||||
|
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
|
events.add(job.status)
|
||||||
|
|
||||||
|
queue = DownloadQueueService(
|
||||||
|
requests_session=session,
|
||||||
|
)
|
||||||
|
queue.start()
|
||||||
|
job = queue.download(
|
||||||
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
|
dest=tmp_path,
|
||||||
|
on_start=event_handler,
|
||||||
|
on_progress=event_handler,
|
||||||
|
on_complete=event_handler,
|
||||||
|
on_error=event_handler,
|
||||||
|
)
|
||||||
|
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
|
||||||
|
assert isinstance(job.id, int), "expected the job id to be numeric"
|
||||||
|
queue.join()
|
||||||
|
|
||||||
|
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||||
|
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||||
|
|
||||||
|
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
|
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||||
|
queue = DownloadQueueService(
|
||||||
|
requests_session=session,
|
||||||
|
)
|
||||||
|
queue.start()
|
||||||
|
|
||||||
|
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
|
||||||
|
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
|
||||||
|
|
||||||
|
queue.join()
|
||||||
|
jobs = queue.list_jobs()
|
||||||
|
print(jobs)
|
||||||
|
assert len(jobs) == 2
|
||||||
|
jobs_dict = {str(x.source): x for x in jobs}
|
||||||
|
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
|
||||||
|
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
|
||||||
|
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
|
||||||
|
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
|
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||||
|
event_bus = TestEventService()
|
||||||
|
|
||||||
|
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||||
|
queue.start()
|
||||||
|
queue.download(
|
||||||
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
|
dest=tmp_path,
|
||||||
|
)
|
||||||
|
queue.join()
|
||||||
|
events = event_bus.events
|
||||||
|
assert len(events) == 3
|
||||||
|
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
||||||
|
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
||||||
|
assert events[0].event_name == "download_started"
|
||||||
|
assert events[1].event_name == "download_progress"
|
||||||
|
assert events[1].payload["total_bytes"] > 0
|
||||||
|
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
||||||
|
assert events[2].event_name == "download_complete"
|
||||||
|
assert events[2].payload["total_bytes"] == 32029
|
||||||
|
|
||||||
|
# test a failure
|
||||||
|
event_bus.events = [] # reset our accumulator
|
||||||
|
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
|
||||||
|
queue.join()
|
||||||
|
events = event_bus.events
|
||||||
|
print("\n".join([x.model_dump_json() for x in events]))
|
||||||
|
assert len(events) == 1
|
||||||
|
assert events[0].event_name == "download_error"
|
||||||
|
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
||||||
|
assert events[0].payload["error"] is not None
|
||||||
|
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
|
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||||
|
queue = DownloadQueueService(
|
||||||
|
requests_session=session,
|
||||||
|
)
|
||||||
|
queue.start()
|
||||||
|
|
||||||
|
callback_ran = False
|
||||||
|
|
||||||
|
def broken_callback(job: DownloadJob) -> None:
|
||||||
|
nonlocal callback_ran
|
||||||
|
callback_ran = True
|
||||||
|
print(1 / 0) # deliberate error here
|
||||||
|
|
||||||
|
job = queue.download(
|
||||||
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
|
dest=tmp_path,
|
||||||
|
on_progress=broken_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
queue.join()
|
||||||
|
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
|
||||||
|
assert Path(tmp_path, "mock12345.safetensors").exists()
|
||||||
|
assert callback_ran
|
||||||
|
# LS: The pytest capsys fixture does not seem to be working. I can see the
|
||||||
|
# correct stderr message in the pytest log, but it is not appearing in
|
||||||
|
# capsys.readouterr().
|
||||||
|
# captured = capsys.readouterr()
|
||||||
|
# assert re.search("division by zero", captured.err)
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
|
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||||
|
event_bus = TestEventService()
|
||||||
|
|
||||||
|
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||||
|
queue.start()
|
||||||
|
|
||||||
|
cancelled = False
|
||||||
|
|
||||||
|
def slow_callback(job: DownloadJob) -> None:
|
||||||
|
time.sleep(2)
|
||||||
|
|
||||||
|
def cancelled_callback(job: DownloadJob) -> None:
|
||||||
|
nonlocal cancelled
|
||||||
|
cancelled = True
|
||||||
|
|
||||||
|
def handler(signum, frame):
|
||||||
|
raise TimeoutError("Join took too long to return")
|
||||||
|
|
||||||
|
job = queue.download(
|
||||||
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
|
dest=tmp_path,
|
||||||
|
on_start=slow_callback,
|
||||||
|
on_cancelled=cancelled_callback,
|
||||||
|
)
|
||||||
|
queue.cancel_job(job)
|
||||||
|
queue.join()
|
||||||
|
|
||||||
|
assert job.status == DownloadJobStatus.CANCELLED
|
||||||
|
assert cancelled
|
||||||
|
events = event_bus.events
|
||||||
|
assert events[-1].event_name == "download_cancelled"
|
||||||
|
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||||
@ -139,7 +294,7 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
|||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=15, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
|
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
@ -172,163 +327,6 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
|
|||||||
assert "download_cancelled" in [x.event_name for x in events]
|
assert "download_cancelled" in [x.event_name for x in events]
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
|
||||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
|
||||||
events = set()
|
|
||||||
|
|
||||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
|
||||||
events.add(job.status)
|
|
||||||
|
|
||||||
queue = DownloadQueueService(
|
|
||||||
requests_session=session,
|
|
||||||
)
|
|
||||||
queue.start()
|
|
||||||
job = queue.download(
|
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
|
||||||
dest=tmp_path,
|
|
||||||
on_start=event_handler,
|
|
||||||
on_progress=event_handler,
|
|
||||||
on_complete=event_handler,
|
|
||||||
on_error=event_handler,
|
|
||||||
)
|
|
||||||
assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
|
|
||||||
assert isinstance(job.id, int), "expected the job id to be numeric"
|
|
||||||
queue.join()
|
|
||||||
|
|
||||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
|
||||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
|
||||||
|
|
||||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
|
||||||
queue.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
|
||||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
|
||||||
queue = DownloadQueueService(
|
|
||||||
requests_session=session,
|
|
||||||
)
|
|
||||||
queue.start()
|
|
||||||
|
|
||||||
for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
|
|
||||||
queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
|
|
||||||
|
|
||||||
queue.join()
|
|
||||||
jobs = queue.list_jobs()
|
|
||||||
print(jobs)
|
|
||||||
assert len(jobs) == 2
|
|
||||||
jobs_dict = {str(x.source): x for x in jobs}
|
|
||||||
assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
|
|
||||||
assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
|
|
||||||
assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
|
|
||||||
assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
|
|
||||||
queue.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
|
||||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
|
||||||
event_bus = TestEventService()
|
|
||||||
|
|
||||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
|
||||||
queue.start()
|
|
||||||
queue.download(
|
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
|
||||||
dest=tmp_path,
|
|
||||||
)
|
|
||||||
queue.join()
|
|
||||||
events = event_bus.events
|
|
||||||
assert len(events) == 3
|
|
||||||
assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
|
|
||||||
assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
|
|
||||||
assert events[0].event_name == "download_started"
|
|
||||||
assert events[1].event_name == "download_progress"
|
|
||||||
assert events[1].payload["total_bytes"] > 0
|
|
||||||
assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
|
|
||||||
assert events[2].event_name == "download_complete"
|
|
||||||
assert events[2].payload["total_bytes"] == 32029
|
|
||||||
|
|
||||||
# test a failure
|
|
||||||
event_bus.events = [] # reset our accumulator
|
|
||||||
queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
|
|
||||||
queue.join()
|
|
||||||
events = event_bus.events
|
|
||||||
print("\n".join([x.model_dump_json() for x in events]))
|
|
||||||
assert len(events) == 1
|
|
||||||
assert events[0].event_name == "download_error"
|
|
||||||
assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
|
|
||||||
assert events[0].payload["error"] is not None
|
|
||||||
assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", events[0].payload["error"])
|
|
||||||
queue.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
|
||||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
|
||||||
queue = DownloadQueueService(
|
|
||||||
requests_session=session,
|
|
||||||
)
|
|
||||||
queue.start()
|
|
||||||
|
|
||||||
callback_ran = False
|
|
||||||
|
|
||||||
def broken_callback(job: DownloadJob) -> None:
|
|
||||||
nonlocal callback_ran
|
|
||||||
callback_ran = True
|
|
||||||
print(1 / 0) # deliberate error here
|
|
||||||
|
|
||||||
job = queue.download(
|
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
|
||||||
dest=tmp_path,
|
|
||||||
on_progress=broken_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
queue.join()
|
|
||||||
assert job.status == DownloadJobStatus.COMPLETED # should complete even though the callback is borked
|
|
||||||
assert Path(tmp_path, "mock12345.safetensors").exists()
|
|
||||||
assert callback_ran
|
|
||||||
# LS: The pytest capsys fixture does not seem to be working. I can see the
|
|
||||||
# correct stderr message in the pytest log, but it is not appearing in
|
|
||||||
# capsys.readouterr().
|
|
||||||
# captured = capsys.readouterr()
|
|
||||||
# assert re.search("division by zero", captured.err)
|
|
||||||
queue.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=15, method="thread")
|
|
||||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
|
||||||
event_bus = TestEventService()
|
|
||||||
|
|
||||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
|
||||||
queue.start()
|
|
||||||
|
|
||||||
cancelled = False
|
|
||||||
|
|
||||||
def slow_callback(job: DownloadJob) -> None:
|
|
||||||
time.sleep(2)
|
|
||||||
|
|
||||||
def cancelled_callback(job: DownloadJob) -> None:
|
|
||||||
nonlocal cancelled
|
|
||||||
cancelled = True
|
|
||||||
|
|
||||||
def handler(signum, frame):
|
|
||||||
raise TimeoutError("Join took too long to return")
|
|
||||||
|
|
||||||
job = queue.download(
|
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
|
||||||
dest=tmp_path,
|
|
||||||
on_start=slow_callback,
|
|
||||||
on_cancelled=cancelled_callback,
|
|
||||||
)
|
|
||||||
queue.cancel_job(job)
|
|
||||||
queue.join()
|
|
||||||
|
|
||||||
assert job.status == DownloadJobStatus.CANCELLED
|
|
||||||
assert cancelled
|
|
||||||
events = event_bus.events
|
|
||||||
assert events[-1].event_name == "download_cancelled"
|
|
||||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
|
||||||
queue.stop()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def clear_config() -> Generator[None, None, None]:
|
def clear_config() -> Generator[None, None, None]:
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user