Mirror of https://github.com/invoke-ai/InvokeAI
Synced 2024-08-30 20:32:17 +00:00
224 lines · 7.3 KiB · Python
|
"""Test the queued download facility"""
|
||
|
import re
|
||
|
import time
|
||
|
from pathlib import Path
|
||
|
from typing import Any, Dict, List
|
||
|
|
||
|
import pytest
|
||
|
import requests
|
||
|
from pydantic import BaseModel
|
||
|
from pydantic.networks import AnyHttpUrl
|
||
|
from requests.sessions import Session
|
||
|
from requests_testadapter import TestAdapter
|
||
|
|
||
|
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||
|
|
||
|
# TestAdapter's class name matches pytest's default ``Test*`` collection
# pattern. Setting ``__test__ = False`` tells pytest not to treat it as a test
# class; otherwise pytest would try to collect it and emit a collection warning.
TestAdapter.__test__ = False
|
||
|
|
||
|
|
||
|
@pytest.fixture
def session() -> requests.sessions.Session:
    """Return a requests session whose civitai.com model URLs are served in-memory.

    Well-formed mock downloads are mounted for model ids 12345, 9999 and 54321.
    Each one answers with a "safetensors" body plus Content-Length and
    Content-Disposition headers.

    Two malformed endpoints are also mounted:
    ``.../missing`` omits Content-Length, and ``.../broken`` answers 404.
    """
    mock_session = requests.Session()

    for model_id in ("12345", "9999", "54321"):
        # Padded with 32,000 zero bytes; the original author notes that pause
        # tests need the content to be large.
        body = b"I am a safetensors file " + model_id.encode("utf-8") + bytes(32_000)
        response_headers = {
            "Content-Length": len(body),
            "Content-Disposition": f'filename="mock{model_id}.safetensors"',
        }
        mock_session.mount(
            f"http://www.civitai.com/models/{model_id}",
            TestAdapter(body, headers=response_headers),
        )

    # Malformed endpoint: the response has no Content-Length header.
    no_length_adapter = TestAdapter(
        b"Missing content length",
        headers={
            "Content-Disposition": 'filename="missing.txt"',
        },
    )
    mock_session.mount("http://www.civitai.com/models/missing", no_length_adapter)

    # Malformed endpoint: the server reports 404 Not Found.
    not_found_adapter = TestAdapter(b"Not found", status=404)
    mock_session.mount("http://www.civitai.com/models/broken", not_found_adapter)

    return mock_session
|
||
|
|
||
|
|
||
|
class DummyEvent(BaseModel):
    """Dummy Event to use with Dummy Event service.

    One instance is recorded per call to ``DummyEventService.dispatch``.
    """

    # Event name, copied from the dispatched payload's "event" entry.
    event_name: str
    # Event data, copied from the dispatched payload's "data" entry.
    payload: Dict[str, Any]
|
||
|
|
||
|
|
||
|
# Event service stand-in that records events in memory so tests can inspect them.
class DummyEventService(EventServiceBase):
    """Test double for the event bus: keeps every dispatched event in ``self.events``."""

    # Events in the order they were dispatched.
    events: List[DummyEvent]

    def __init__(self) -> None:
        super().__init__()
        self.events = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Capture an event rather than broadcasting it.

        The recorded name comes from ``payload["event"]`` and the recorded data
        from ``payload["data"]``. The ``event_name`` argument itself is ignored.

        NOTE(review): presumably ``event_name`` is a generic channel name and the
        specific event name is carried inside ``payload``; confirm against
        EventServiceBase.
        """
        specific_name = payload["event"]
        event_data = payload["data"]
        captured = DummyEvent(event_name=specific_name, payload=event_data)
        self.events.append(captured)
|
||
|
|
||
|
|
||
|
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
    """A single download completes, writes its file, and reports RUNNING then COMPLETED via callbacks."""
    events = set()

    def event_handler(job: DownloadJob) -> None:
        # Record the status seen at each callback. A set ignores repeats, since
        # progress callbacks may fire several times.
        events.add(job.status)

    queue = DownloadQueueService(
        requests_session=session,
    )
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_start=event_handler,
            on_progress=event_handler,
            on_complete=event_handler,
            on_error=event_handler,
        )
        assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJob"
        assert isinstance(job.id, int), "expected the job id to be numeric"
        queue.join()

        assert job.status == DownloadJobStatus.COMPLETED, "expected job status to be completed"
        assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"

        assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
    finally:
        # Stop the queue even when an assertion fails, so it is not left
        # running into later tests.
        queue.stop()
|
||
|
|
||
|
|
||
|
def test_errors(tmp_path: Path, session: Session) -> None:
    """A 404 URL ends in ERROR; a response without Content-Length still completes with total_bytes 0."""
    queue = DownloadQueueService(
        requests_session=session,
    )
    queue.start()
    try:
        for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
            queue.download(AnyHttpUrl(bad_url), dest=tmp_path)

        queue.join()
        jobs = queue.list_jobs()
        # Include the job list in the failure message instead of printing it unconditionally.
        assert len(jobs) == 2, jobs

        jobs_dict = {str(x.source): x for x in jobs}
        assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
        assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
        assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
        assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
    finally:
        # Stop the queue even when an assertion fails.
        queue.stop()
|
||
|
|
||
|
|
||
|
def test_event_bus(tmp_path: Path, session: Session) -> None:
    """Events reach the event bus: started/progress/complete for a success, a single error event for a 404."""
    event_bus = DummyEventService()

    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()
    try:
        queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
        )
        queue.join()
        events = event_bus.events
        assert len(events) == 3
        assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
        assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
        assert events[0].event_name == "download_started"
        assert events[1].event_name == "download_progress"
        assert events[1].payload["total_bytes"] > 0
        assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
        assert events[2].event_name == "download_complete"
        # 32029 = 24-byte "I am a safetensors file " prefix + 5-byte id "12345"
        # + 32,000 padding bytes, as built by the session fixture.
        assert events[2].payload["total_bytes"] == 32029

        # test a failure
        event_bus.events = []  # reset our accumulator
        queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
        queue.join()
        events = event_bus.events
        assert len(events) == 1, [x.model_dump_json() for x in events]
        assert events[0].event_name == "download_error"
        assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
        assert events[0].payload["error"] is not None
        # Escape the pattern so "." matches a literal dot, not any character.
        assert re.search(re.escape("requests.exceptions.HTTPError: NOT FOUND"), events[0].payload["error"])
    finally:
        # Stop the queue even when an assertion fails.
        queue.stop()
|
||
|
|
||
|
|
||
|
def test_broken_callbacks(tmp_path: Path, session: requests.sessions.Session, capsys) -> None:
    """A callback that raises must not stop the download from completing."""
    callback_ran = False

    def broken_callback(job: DownloadJob) -> None:
        nonlocal callback_ran
        callback_ran = True
        print(1 / 0)  # deliberate error here

    queue = DownloadQueueService(
        requests_session=session,
    )
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_progress=broken_callback,
        )
        queue.join()

        assert job.status == DownloadJobStatus.COMPLETED  # should complete even though the callback is borked
        assert Path(tmp_path, "mock12345.safetensors").exists()
        assert callback_ran
        # TODO(LS): also assert that the "division by zero" error is reported.
        # capsys.readouterr().err does not show it, although the message does
        # appear in the pytest log output. It may be emitted through logging
        # rather than written to stderr; if so, the caplog fixture would be
        # needed. Confirm before re-enabling the check.
    finally:
        # Stop the queue even when an assertion fails.
        queue.stop()
|
||
|
|
||
|
|
||
|
def test_cancel(tmp_path: Path, session: requests.sessions.Session) -> None:
    """Cancelling a job while its start callback is running ends in CANCELLED.

    Also checks that the on_cancelled callback runs and that a
    download_cancelled event reaches the bus.
    """
    event_bus = DummyEventService()
    cancelled = False

    def slow_callback(job: DownloadJob) -> None:
        # Keep the job in its start callback long enough that cancel_job()
        # below is issued while the job is still starting.
        time.sleep(2)

    def cancelled_callback(job: DownloadJob) -> None:
        nonlocal cancelled
        cancelled = True

    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_start=slow_callback,
            on_cancelled=cancelled_callback,
        )
        queue.cancel_job(job)
        queue.join()

        assert job.status == DownloadJobStatus.CANCELLED
        assert cancelled
        events = event_bus.events
        # Guard so an empty bus fails with a clear message instead of an IndexError.
        assert events, "expected at least one event on the bus"
        assert events[-1].event_name == "download_cancelled"
        assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
    finally:
        # Stop the queue even when an assertion fails.
        queue.stop()
|