refactor model_install to work with refactored download queue

This commit is contained in:
Lincoln Stein
2024-05-13 22:49:15 -04:00
parent 287c679f7b
commit f29c406fed
8 changed files with 226 additions and 220 deletions

View File

@ -4,79 +4,33 @@ import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional
from typing import Any, Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from requests_testadapter import TestAdapter
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
@pytest.fixture
def session() -> Session:
sess = TestSession()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess
TestAdapter.__test__ = False
@pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set()
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
job = queue.download(
@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
print(f"bytes = {job.bytes}")
events.add(job.status)
queue = DownloadQueueService(
@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert job.error_type is not None
assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
assert "download_cancelled" in [x.event_name for x in events]
queue.stop()
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=[
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
],
dest=tmp_path,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
queue.stop()
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]:
get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session):
def test_tokens(tmp_path: Path, mm2_session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=session)
queue = DownloadQueueService(requests_session=mm2_session)
queue.start()
# this one has an access token assigned
job1 = queue.download(

View File

@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
# TODO: Test subfolder download
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
job = mm2_installer.download_diffusers_model(source, tmp_path)
mm2_installer.wait_for_installs(timeout=5)
print(job.local_path)
assert job.status == InstallStatus.DOWNLOADS_DONE
assert (tmp_path / "sdxl-turbo").exists()
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
job = mm2_installer.import_model(source)
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
@ -327,7 +349,6 @@ def test_other_error_during_install(
assert job.error == "Test error"
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
@pytest.mark.parametrize(
"model_params",
[

View File

@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
},
),
)
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess