mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
4536e4a8b6
* add basic functionality for model metadata fetching from hf and civitai * add storage * start unit tests * add unit tests and documentation * add missing dependency for pytests * remove redundant fetch; add modified/published dates; updated docs * add code to select diffusers files based on the variant type * implement Civitai installs * make huggingface parallel downloading work * add unit tests for model installation manager - Fixed race condition on selection of download destination path - Add fixtures common to several model_manager_2 unit tests - Added dummy model files for testing diffusers and safetensors downloading/probing - Refactored code for selecting proper variant from list of huggingface repo files - Regrouped ordering of methods in model_install_default.py * improve Civitai model downloading - Provide a better error message when Civitai requires an access token (doesn't give a 403 forbidden, but redirects to the HTML of an authorization page -- arrgh) - Handle case of Civitai providing a primary download link plus additional links for VAEs, config files, etc * add routes for retrieving metadata and tags * code tidying and documentation * fix ruff errors * add file needed to maintain test root diretory in repo for unit tests * fix self->cls in classmethod * add pydantic plugin for mypy * use TestSession instead of requests.Session to prevent any internet activity improve logging fix error message formatting fix logging again fix forward vs reverse slash issue in Windows install tests * Several fixes of problems detected during PR review: - Implement cancel_model_install_job and get_model_install_job routes to allow for better control of model download and install. - Fix thread deadlock that occurred after cancelling an install. - Remove unneeded pytest_plugins section from tests/conftest.py - Remove unused _in_terminal_state() from model_install_default. - Remove outdated documentation from several spots. 
- Add workaround for Civitai API results which don't return correct URL for the default model. * fix docs and tests to match get_job_by_source() rather than get_job() * Update invokeai/backend/model_manager/metadata/fetch/huggingface.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Call CivitaiMetadata.model_validate_json() directly Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * Second round of revisions suggested by @ryanjdick: - Fix type mismatch in `list_all_metadata()` route. - Do not have a default value for the model install job id - Remove static class variable declarations from non Pydantic classes - Change `id` field to `model_id` for the sqlite3 `model_tags` table. - Changed AFTER DELETE triggers to ON DELETE CASCADE for the metadata and tags tables. - Made the `id` field of the `model_metadata` table into a primary key to achieve uniqueness. * Code cleanup suggested in PR review: - Narrowed the declaration of the `parts` attribute of the download progress event - Removed auto-conversion of str to Url in Url-containing sources - Fixed handling of `InvalidModelConfigException` - Made unknown sources raise `NotImplementedError` rather than `Exception` - Improved status reporting on cached HuggingFace access tokens * Multiple fixes: - `job.total_size` returns a valid size for locally installed models - new route `list_models` returns a paged summary of model, name, description, tags and other essential info - fix a few type errors * consolidated all invokeai root pytest fixtures into a single location * Update invokeai/backend/model_manager/metadata/metadata_store.py Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com> * Small tweaks in response to review comments: - Remove flake8 configuration from pyproject.toml - Use `id` rather than `modelId` for huggingface `ModelInfo` object - Use `last_modified` rather than `LastModified` for huggingface `ModelInfo` object - Add `sha256` field to file metadata downloaded from 
huggingface - Add `Invoker` argument to the model installer `start()` and `stop()` routines (but made it optional in order to facilitate use of the service outside the API) - Removed redundant `PRAGMA foreign_keys` from metadata store initialization code. * Additional tweaks and minor bug fixes - Fix calculation of aggregate diffusers model size to only count the size of files, not files + directories (which gives different unit test results on different filesystems). - Refactor _get_metadata() and _get_download_urls() to have distinct code paths for Civitai, HuggingFace and URL sources. - Forward the `inplace` flag from the source to the job and added unit test for this. - Attach cached model metadata to the job rather than to the model install service. * fix unit test that was breaking on windows due to CR/LF changing size of test json files * fix ruff formatting * a few last minor fixes before merging: - Turn job `error` and `error_type` into properties derived from the exception. - Add TODO comment about the reason for handling temporary directory destruction manually rather than using tempfile.tmpdir(). * add unit tests for reporting HTTP download errors --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
223 lines
7.2 KiB
Python
223 lines
7.2 KiB
Python
"""Test the queued download facility"""
|
|
import re
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
|
|
import pytest
|
|
from pydantic import BaseModel
|
|
from pydantic.networks import AnyHttpUrl
|
|
from requests.sessions import Session
|
|
from requests_testadapter import TestAdapter, TestSession
|
|
|
|
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
|
from invokeai.app.services.events.events_base import EventServiceBase
|
|
|
|
# TestAdapter's name begins with "Test", so pytest would otherwise try to
# collect it as a test class and emit deprecation warnings; opting out of
# collection via __test__ = False prevents that.
TestAdapter.__test__ = False
|
|
|
|
|
|
@pytest.fixture
def session() -> Session:
    """Return a requests TestSession with mock Civitai endpoints mounted.

    Provides three well-formed model URLs plus two deliberately malformed
    ones (missing Content-Length, and a 404) for error-path tests.
    """
    mock_session = TestSession()

    # Well-formed endpoints. Payloads are padded to ~32KB so that pause and
    # cancel tests have enough bytes in flight to interrupt mid-download.
    for model_id in ("12345", "9999", "54321"):
        body = b"I am a safetensors file " + bytearray(model_id, "utf-8") + bytearray(32_000)
        adapter = TestAdapter(
            body,
            headers={
                "Content-Length": len(body),
                "Content-Disposition": f'filename="mock{model_id}.safetensors"',
            },
        )
        mock_session.mount(f"http://www.civitai.com/models/{model_id}", adapter)

    # Malformed endpoint: response without a Content-Length header.
    mock_session.mount(
        "http://www.civitai.com/models/missing",
        TestAdapter(
            b"Missing content length",
            headers={
                "Content-Disposition": 'filename="missing.txt"',
            },
        ),
    )

    # Malformed endpoint: returns a 404.
    mock_session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))

    return mock_session
|
|
|
|
|
|
class DummyEvent(BaseModel):
    """Minimal event record captured by DummyEventService.

    Mirrors the (event name, payload) pairs handed to
    EventServiceBase.dispatch() by the download queue.
    """

    # Name of the dispatched event, e.g. "download_started".
    event_name: str
    # The event's data dictionary exactly as dispatched.
    payload: Dict[str, Any]
|
|
|
|
|
|
class DummyEventService(EventServiceBase):
    """In-memory event service that records every dispatched event for later inspection."""

    # Accumulates events in dispatch order; tests read (and sometimes reset) this list.
    events: List[DummyEvent]

    def __init__(self) -> None:
        super().__init__()
        self.events = []

    def dispatch(self, event_name: str, payload: Any) -> None:
        """Record the event in self.events instead of broadcasting it."""
        captured = DummyEvent(event_name=payload["event"], payload=payload["data"])
        self.events.append(captured)
|
|
|
|
|
|
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
    """A single well-formed download should finish and fire RUNNING/COMPLETED callbacks."""
    seen_statuses = set()

    def record_status(job: DownloadJob) -> None:
        seen_statuses.add(job.status)

    queue = DownloadQueueService(requests_session=session)
    queue.start()
    job = queue.download(
        source=AnyHttpUrl("http://www.civitai.com/models/12345"),
        dest=tmp_path,
        on_start=record_status,
        on_progress=record_status,
        on_complete=record_status,
        on_error=record_status,
    )
    assert isinstance(job, DownloadJob), "expected the job to be of type DownloadJobBase"
    assert isinstance(job.id, int), "expected the job id to be numeric"
    queue.join()

    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
    assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"

    # Only the running and completed transitions should have been observed.
    assert seen_statuses == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
    queue.stop()
|
|
|
|
|
|
def test_errors(tmp_path: Path, session: Session) -> None:
    """Verify handling of a 404 response and of a response missing Content-Length.

    - A 404 must leave the job in ERROR state with the HTTPError recorded
      in ``error_type``.
    - A missing Content-Length header is tolerated: the job completes but
      reports ``total_bytes == 0``.
    """
    queue = DownloadQueueService(
        requests_session=session,
    )
    queue.start()

    for bad_url in ["http://www.civitai.com/models/broken", "http://www.civitai.com/models/missing"]:
        queue.download(AnyHttpUrl(bad_url), dest=tmp_path)

    queue.join()
    jobs = queue.list_jobs()
    # Removed a leftover debug print(jobs) that cluttered test output.
    assert len(jobs) == 2
    jobs_dict = {str(x.source): x for x in jobs}
    assert jobs_dict["http://www.civitai.com/models/broken"].status == DownloadJobStatus.ERROR
    assert jobs_dict["http://www.civitai.com/models/broken"].error_type == "HTTPError(NOT FOUND)"
    assert jobs_dict["http://www.civitai.com/models/missing"].status == DownloadJobStatus.COMPLETED
    assert jobs_dict["http://www.civitai.com/models/missing"].total_bytes == 0
    queue.stop()
|
|
|
|
|
|
def test_event_bus(tmp_path: Path, session: Session) -> None:
    """Events published on the bus should describe the full download lifecycle in order."""
    event_bus = DummyEventService()

    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()
    queue.download(
        source=AnyHttpUrl("http://www.civitai.com/models/12345"),
        dest=tmp_path,
    )
    queue.join()

    captured = event_bus.events
    assert len(captured) == 3
    # Timestamps must be non-decreasing across the lifecycle.
    assert captured[0].payload["timestamp"] <= captured[1].payload["timestamp"]
    assert captured[1].payload["timestamp"] <= captured[2].payload["timestamp"]
    assert captured[0].event_name == "download_started"
    assert captured[1].event_name == "download_progress"
    assert captured[1].payload["total_bytes"] > 0
    assert captured[1].payload["current_bytes"] <= captured[1].payload["total_bytes"]
    assert captured[2].event_name == "download_complete"
    assert captured[2].payload["total_bytes"] == 32029

    # Exercise the failure path against the 404 endpoint.
    event_bus.events = []  # reset our accumulator
    queue.download(source=AnyHttpUrl("http://www.civitai.com/models/broken"), dest=tmp_path)
    queue.join()
    captured = event_bus.events
    print("\n".join(x.model_dump_json() for x in captured))
    assert len(captured) == 1
    assert captured[0].event_name == "download_error"
    assert captured[0].payload["error_type"] == "HTTPError(NOT FOUND)"
    assert captured[0].payload["error"] is not None
    assert re.search(r"requests.exceptions.HTTPError: NOT FOUND", captured[0].payload["error"])
    queue.stop()
|
|
|
|
|
|
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
    """An exception raised inside a progress callback must not abort the download."""
    queue = DownloadQueueService(
        requests_session=session,
    )
    queue.start()

    exploded = False

    def exploding_callback(job: DownloadJob) -> None:
        nonlocal exploded
        exploded = True
        print(1 / 0)  # deliberate error here

    job = queue.download(
        source=AnyHttpUrl("http://www.civitai.com/models/12345"),
        dest=tmp_path,
        on_progress=exploding_callback,
    )

    queue.join()
    # The download should complete even though the callback is borked.
    assert job.status == DownloadJobStatus.COMPLETED
    assert Path(tmp_path, "mock12345.safetensors").exists()
    assert exploded
    # LS: The pytest capsys fixture does not seem to be working. I can see the
    # correct stderr message in the pytest log, but it is not appearing in
    # capsys.readouterr().
    # captured = capsys.readouterr()
    # assert re.search("division by zero", captured.err)
    queue.stop()
|
|
|
|
|
|
def test_cancel(tmp_path: Path, session: Session) -> None:
    """Cancelling an in-flight job should fire on_cancelled and emit a cancel event."""
    event_bus = DummyEventService()

    queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
    queue.start()

    was_cancelled = False

    def stall(job: DownloadJob) -> None:
        # Sleep on start so the cancel request lands while the job is in flight.
        time.sleep(2)

    def note_cancellation(job: DownloadJob) -> None:
        nonlocal was_cancelled
        was_cancelled = True

    job = queue.download(
        source=AnyHttpUrl("http://www.civitai.com/models/12345"),
        dest=tmp_path,
        on_start=stall,
        on_cancelled=note_cancellation,
    )
    queue.cancel_job(job)
    queue.join()

    assert job.status == DownloadJobStatus.CANCELLED
    assert was_cancelled
    events = event_bus.events
    assert events[-1].event_name == "download_cancelled"
    assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
    queue.stop()
|