# Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00).
# Original file: 364 lines, 12 KiB, Python.
"""Test the queued download facility"""
|
|
|
|
import re
|
|
import time
|
|
from contextlib import contextmanager
|
|
from pathlib import Path
|
|
from typing import Generator, Optional
|
|
|
|
import pytest
|
|
from pydantic.networks import AnyHttpUrl
|
|
from requests.sessions import Session
|
|
from requests_testadapter import TestAdapter, TestSession
|
|
|
|
from invokeai.app.services.config import get_config
|
|
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
|
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
|
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
|
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
|
from tests.test_nodes import TestEventService
|
|
|
|
# pytest collects every module-level class whose name matches ``Test*`` -- including imported
# ones -- and emits a collection warning when such a class cannot be collected (e.g. because it
# defines __init__). These requests_testadapter helpers are not test classes, so opt them out.
TestAdapter.__test__ = False  # type: ignore
TestSession.__test__ = False  # type: ignore
|
|
|
|
|
|
@pytest.fixture
def session() -> Session:
    """Return a ``TestSession`` with canned responses for the single-file download tests.

    Mounted URLs:

    * ``http://www.civitai.com/models/{12345,9999,54321}`` -- mock safetensors payloads served
      as ``mock<id>.safetensors`` with an explicit ``Content-Length`` header.
    * ``http://www.huggingface.co/foo.txt`` -- the model-54321 payload served as
      ``foo.safetensors`` (this URL does not match the ``civitai`` token regex used in
      ``test_tokens``).
    * ``http://www.civitai.com/models/missing`` -- a body with no ``Content-Length`` header.
    * ``http://www.civitai.com/models/broken`` -- a 404 response.
    """
    sess = TestSession()

    def mock_payload(model_id: str) -> bytes:
        # Recognisable prefix + id + 32,000 zero bytes. Kept large on purpose: the original note
        # says pause tests need a body that is not transferred instantly.
        return b"I am a safetensors file " + bytearray(model_id, "utf-8") + bytearray(32_000)

    def mount_file(url: str, content: bytes, filename: str) -> None:
        # Serve `content` at `url` with the size and file-name headers the tests rely on.
        sess.mount(
            url,
            TestAdapter(
                content,
                headers={
                    "Content-Length": len(content),
                    "Content-Disposition": f'filename="{filename}"',
                },
            ),
        )

    for model_id in ["12345", "9999", "54321"]:
        mount_file(
            f"http://www.civitai.com/models/{model_id}",
            mock_payload(model_id),
            f"mock{model_id}.safetensors",
        )

    # The original reused `content` leaked out of the loop above (i.e. the "54321" payload);
    # build that payload explicitly so editing the loop cannot silently change this response.
    mount_file("http://www.huggingface.co/foo.txt", mock_payload("54321"), "foo.safetensors")

    # Malformed endpoints.
    # Missing the Content-Length header:
    sess.mount(
        "http://www.civitai.com/models/missing",
        TestAdapter(
            b"Missing content length",
            headers={
                "Content-Disposition": 'filename="missing.txt"',
            },
        ),
    )
    # Not found:
    sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))

    return sess
|
|
|
|
|
|
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
    """Download every file of the mocked ``stabilityai/sdxl-turbo`` repo as one multi-file job.

    Checks that the job completes and that its byte counters agree. Checks that representative
    files land under ``tmp_path/sdxl-turbo``. Checks that the callbacks only observed the RUNNING
    and COMPLETED states.
    """
    fetcher = HuggingFaceMetadataFetch(mm2_session)
    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
    events = set()

    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
        print(f"bytes = {job.bytes}")
        events.add(job.status)

    queue = DownloadQueueService(requests_session=mm2_session)
    queue.start()
    try:
        job = queue.multifile_download(
            parts=metadata.download_urls(session=mm2_session),
            dest=tmp_path,
            on_start=event_handler,
            on_progress=event_handler,
            on_complete=event_handler,
            on_error=event_handler,
        )
        assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
        queue.join()

        assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
        assert job.bytes > 0, "expected download bytes to be positive"
        assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
        assert Path(
            tmp_path, "sdxl-turbo/model_index.json"
        ).exists(), f"expected {tmp_path}/sdxl-turbo/model_index.json to exist"
        assert Path(
            tmp_path, "sdxl-turbo/text_encoder/config.json"
        ).exists(), f"expected {tmp_path}/sdxl-turbo/text_encoder/config.json to exist"

        assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
    finally:
        # Stop the queue even when an assertion fails so it does not outlive this test.
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
    """A multi-file job with one unreachable part must end in the ERROR state.

    A part pointing at ``https://test.com/missing_model.safetensors`` is appended to the real
    sdxl-turbo parts. That URL is expected to produce a 404 from ``mm2_session``. The job must
    report an ``HTTPError(NOT FOUND)`` error type, and an error callback must observe ERROR.
    """
    fetcher = HuggingFaceMetadataFetch(mm2_session)
    metadata = fetcher.from_id("stabilityai/sdxl-turbo")
    events = set()

    def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
        events.add(job.status)

    queue = DownloadQueueService(requests_session=mm2_session)
    queue.start()
    try:
        files = metadata.download_urls(session=mm2_session)
        # this will give a 404 error
        files.append(RemoteModelFile(url="https://test.com/missing_model.safetensors", path=Path("sdxl-turbo/broken")))
        job = queue.multifile_download(
            parts=files,
            dest=tmp_path,
            on_start=event_handler,
            on_progress=event_handler,
            on_complete=event_handler,
            on_error=event_handler,
        )
        queue.join()

        assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
        # Guard first: `in None` would raise TypeError instead of a readable assertion failure.
        assert job.error_type is not None, "expected the job to record an error type"
        assert "HTTPError(NOT FOUND)" in job.error_type
        assert DownloadJobStatus.ERROR in events
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=15, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
    """Cancelling a multi-file job right after submission leaves it CANCELLED.

    Verifies the job status and that the ``on_cancelled`` callback fired. Also verifies that a
    ``download_cancelled`` event reached the event bus. ``monkeypatch`` is unused; it is kept
    so the test signature does not change.
    """
    event_bus = TestEventService()
    cancelled = False

    def cancelled_callback(job: DownloadJob) -> None:
        nonlocal cancelled
        cancelled = True

    queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
    queue.start()
    try:
        fetcher = HuggingFaceMetadataFetch(mm2_session)
        metadata = fetcher.from_id("stabilityai/sdxl-turbo")

        job = queue.multifile_download(
            parts=metadata.download_urls(session=mm2_session),
            dest=tmp_path,
            on_cancelled=cancelled_callback,
        )
        queue.cancel_job(job)
        queue.join()

        assert job.status == DownloadJobStatus.CANCELLED
        assert cancelled
        assert "download_cancelled" in [x.event_name for x in event_bus.events]
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=20, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
    """Download one mock file through the queue; check status, output file and callback states."""
    events = set()

    def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
        events.add(job.status)

    queue = DownloadQueueService(requests_session=session)
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_start=event_handler,
            on_progress=event_handler,
            on_complete=event_handler,
            on_error=event_handler,
        )
        assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJob"
        assert isinstance(job.id, int), "expected the job id to be numeric"
        queue.join()

        assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
        # The name comes from the Content-Disposition header mounted by the `session` fixture.
        assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"

        assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=20, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
    """Exercise the two malformed endpoints mounted by the ``session`` fixture.

    ``/models/broken`` answers 404 and must end in ERROR with ``HTTPError(NOT FOUND)``.
    ``/models/missing`` has no Content-Length header, but must still complete with
    ``total_bytes == 0``.
    """
    queue = DownloadQueueService(requests_session=session)
    queue.start()
    try:
        for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
            queue.download(AnyHttpUrl(bad_url), dest=tmp_path)
        queue.join()

        jobs = queue.list_jobs()
        print(jobs)  # captured by pytest; helps when diagnosing a failure
        assert len(jobs) == 2
        jobs_dict = {str(x.source): x for x in jobs}
        assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
        assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
        assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
        assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=20, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
    """Check the events published for one successful and one failed download.

    Success must produce exactly three events in order: started, progress and complete, with
    non-decreasing timestamps. Failure must produce a single ``download_error`` event that
    carries the HTTPError traceback text.
    """
    event_bus = TestEventService()
    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()
    try:
        queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
        )
        queue.join()
        events = event_bus.events
        assert len(events) == 3
        assert events[0].payload["timestamp"] <= events[1].payload["timestamp"]
        assert events[1].payload["timestamp"] <= events[2].payload["timestamp"]
        assert events[0].event_name == "download_started"
        assert events[1].event_name == "download_progress"
        assert events[1].payload["total_bytes"] > 0
        assert events[1].payload["current_bytes"] <= events[1].payload["total_bytes"]
        assert events[2].event_name == "download_complete"
        # 24-byte "I am a safetensors file " prefix + "12345" + 32_000 padding (see `session`).
        assert events[2].payload["total_bytes"] == 32029

        # test a failure
        event_bus.events = []  # reset our accumulator
        queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
        queue.join()
        events = event_bus.events
        print("\n".join([x.model_dump_json() for x in events]))
        assert len(events) == 1
        assert events[0].event_name == "download_error"
        assert events[0].payload["error_type"] == "HTTPError(NOT FOUND)"
        assert events[0].payload["error"] is not None
        # Dots escaped so they match literally rather than any character.
        assert re.search(r"requests\.exceptions\.HTTPError: NOT FOUND", events[0].payload["error"])
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=20, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
    """A callback that raises must not stop the download from completing."""
    callback_ran = False

    def broken_callback(job: DownloadJob) -> None:
        nonlocal callback_ran
        callback_ran = True
        print(1 / 0)  # deliberate error here

    queue = DownloadQueueService(requests_session=session)
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_progress=broken_callback,
        )
        queue.join()

        assert job.status == DownloadJobStatus.COMPLETED  # should complete even though the callback is borked
        assert Path(tmp_path, "mock12345.safetensors").exists()
        assert callback_ran
        # NOTE: `capsys` is requested but unused. Per LS, the "division by zero" message shows up
        # in the pytest log but not in `capsys.readouterr().err`, so that check is disabled.
        # If the queue reports callback errors via `logging`, the `caplog` fixture would capture
        # them instead -- TODO confirm and restore the assertion.
    finally:
        queue.stop()
|
|
|
|
|
|
@pytest.mark.timeout(timeout=15, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
    """Cancel a single-file download immediately after submitting it.

    ``on_start`` sleeps for 2 seconds, presumably so the job cannot finish before the
    cancellation takes effect. The job must end CANCELLED and ``on_cancelled`` must fire. The
    last event on the bus must be ``download_cancelled`` for the job's source URL.
    """
    event_bus = TestEventService()
    cancelled = False

    def slow_callback(job: DownloadJob) -> None:
        time.sleep(2)

    def cancelled_callback(job: DownloadJob) -> None:
        nonlocal cancelled
        cancelled = True

    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()
    try:
        job = queue.download(
            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
            dest=tmp_path,
            on_start=slow_callback,
            on_cancelled=cancelled_callback,
        )
        queue.cancel_job(job)
        queue.join()

        assert job.status == DownloadJobStatus.CANCELLED
        assert cancelled
        events = event_bus.events
        assert events[-1].event_name == "download_cancelled"
        assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
    finally:
        queue.stop()
|
|
|
|
|
|
@contextmanager
def clear_config() -> Generator[None, None, None]:
    """Reset the cached application config when the ``with`` block exits.

    Any changes made inside the block to the object returned by ``get_config()`` are discarded
    afterwards. The cache is cleared on exit, even if the block raised.
    """
    try:
        yield
    finally:
        # Drop the cached config so the next get_config() call builds a fresh one.
        get_config.cache_clear()
|
|
|
|
|
|
def test_tokens(tmp_path: Path, session: Session):
    """Access tokens from ``remote_api_tokens`` are attached only to URLs matching their regex.

    A ``civitai`` regex/token pair is installed on the cached config, which ``clear_config``
    discards on exit. The civitai download must carry the token; the huggingface one must not.
    """
    with clear_config():
        config = get_config()
        # The token is configured here, in-process: the "civitai" regex matches job1's URL only.
        config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
        queue = DownloadQueueService(requests_session=session)
        queue.start()
        try:
            # this one has an access token assigned
            job1 = queue.download(
                source=AnyHttpUrl("http://www.civitai.com/models/12345"),
                dest=tmp_path,
            )
            # this one doesn't
            job2 = queue.download(
                source=AnyHttpUrl(
                    "http://www.huggingface.co/foo.txt",
                ),
                dest=tmp_path,
            )
            queue.join()

            assert job1.access_token == "cv_12345"
            assert job2.access_token is None
        finally:
            queue.stop()
|