refactor model_install to work with refactored download queue

This commit is contained in:
Lincoln Stein 2024-05-13 22:49:15 -04:00
parent 287c679f7b
commit f29c406fed
8 changed files with 226 additions and 220 deletions

View File

@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads."""
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase):
# required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process
@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC):
@abstractmethod
def multifile_download(
self,
parts: Set[RemoteModelFile],
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass
@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated download job has completed,
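
Taken together, these signature changes unify single- and multifile handling behind DownloadJobBase: multifile_download takes an ordered List[RemoteModelFile], and wait_for_job/cancel_job accept any job subclass. A minimal usage sketch of the refactored interface (the URLs, paths, and timeout below are placeholder assumptions, not part of this commit):

# Hypothetical caller of the refactored queue API; URLs and paths are placeholders.
from pathlib import Path

from pydantic.networks import AnyHttpUrl
from requests.sessions import Session

from invokeai.app.services.download import DownloadQueueService
from invokeai.backend.model_manager.metadata import RemoteModelFile

queue = DownloadQueueService(requests_session=Session())
queue.start()

# an ordered List[RemoteModelFile] replaces the previous Set
parts = [
    RemoteModelFile(url=AnyHttpUrl("https://example.com/repo/config.json"), path=Path("config.json")),
    RemoteModelFile(url=AnyHttpUrl("https://example.com/repo/weights.safetensors"), path=Path("weights.safetensors")),
]
job = queue.multifile_download(parts=parts, dest=Path("/tmp/models"))

# wait_for_job (and cancel_job) now accept any DownloadJobBase subclass
queue.wait_for_job(job, timeout=60)
queue.stop()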

View File

@ -113,9 +113,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service."
)
with self._lock:
job.id = self._next_job_id
self._next_job_id += 1
job.id = self._next_id()
job.set_callbacks(
on_start=on_start,
on_progress=on_progress,
@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase):
def multifile_download(
self,
parts: Set[RemoteModelFile],
parts: List[RemoteModelFile],
dest: Path,
access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest)
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks(
on_start=on_start,
on_progress=on_progress,
@ -190,6 +189,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
)
mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj)
return mfdj
@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Wait for all jobs to complete."""
self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
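
The new _next_id helper centralizes the ID assignment that submit_download_job previously did inline, so every job type gets its id from the same locked counter. The same pattern in isolation:

# Standalone sketch of the _next_id pattern: a lock serializes the
# read-and-increment so concurrent submitters never receive the same id.
import threading

class IdAllocator:
    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._next_job_id = 0

    def next_id(self) -> int:
        with self._lock:
            allocated = self._next_job_id
            self._next_job_id += 1
            return allocated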
def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs."""
return list(self._jobs.values())
@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None:
def cancel_job(self, job: DownloadJobBase) -> None:
"""
Cancel the indicated job.
@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None:
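
The download_path assignment above records only the first path component beneath the destination directory, so a multifile job that materializes a repo folder reports that top-level folder rather than the first file inside it. The pathlib logic in isolation, with made-up paths:

# Illustration of the "first component of the path" logic, with made-up paths.
from pathlib import Path

dest = Path("/tmp/tmpdir_xyz")
part_path = Path("/tmp/tmpdir_xyz/sdxl-turbo/unet/model.fp16.safetensors")

relative = part_path.relative_to(dest)  # sdxl-turbo/unet/model.fp16.safetensors
top_level = dest / relative.parts[0]    # /tmp/tmpdir_xyz/sdxl-turbo
assert top_level == Path("/tmp/tmpdir_xyz/sdxl-turbo")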

View File

@ -6,14 +6,14 @@ import traceback
from abc import ABC, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase
@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel):
source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source"
)
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception"
)
@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel):
)
# internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_do_install: Optional[bool] = PrivateAttr(default=True)
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None:
@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC):
"""
@abstractmethod
def download_diffusers_model(
self,
source: HFModelSource,
download_to: Path,
) -> ModelInstallJob:
"""
Download, but do not install, a diffusers model.
:param source: An HFModelSource object containing a repo_id
:param download_to: Path to directory that will contain the downloaded model.
Returns: a ModelInstallJob
"""
@abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -9,7 +9,7 @@ from pathlib import Path
from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union
import torch
import yaml
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False
self._session = session
self._install_thread: Optional[threading.Thread] = None
@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job)
return install_job
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
return self._import_from_hf(source, download_path=download_to)
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs
@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
job.cancel()
with self._lock:
self._cancel_download_parts(job)
self._logger.warning(f"Cancelling {job.source}")
if dj := job._download_job:
self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None:
"""Prune all completed and errored jobs."""
@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config")
if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path)
@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
job.set_error(
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
download_job = install_job._download_job
if download_job and any(
x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
)
)
else:
job.set_error(excp)
self._signal_job_errored(job)
install_job.set_error(excp)
self._signal_job_errored(install_job)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory
@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory.
"""
installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
}
@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase):
self,
source: HFModelSource,
config: Optional[Dict[str, Any]] = None,
download_path: Optional[Path] = None,
) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token()
@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase):
config=config,
remote_files=remote_files,
metadata=metadata,
download_path=download_path,
)
def _import_from_url(
@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]],
download_path: Optional[Path] = None, # if defined, download only - don't install!
) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found")
destdir = download_path or Path(
destdir = Path(
mkdtemp(
dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX,
@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
bytes=0,
total_bytes=0,
)
# remember the temporary directory for later removal
install_job._install_tmpdir = destdir
# In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid
# creating unwanted subfolders
@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase):
# we remember the path up to the top of the destdir so that it may be
# removed safely at the end of the install process.
install_job._install_tmpdir = destdir
install_job._do_install = download_path is None
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
parts: List[RemoteModelFile] = []
for model_file in remote_files:
url = model_file.url
path = root / model_file.path.relative_to(subfolder)
self._logger.debug(f"Downloading {url} => {path}")
assert install_job.total_bytes is not None
assert model_file.size is not None
install_job.total_bytes += model_file.size
assert hasattr(source, "access_token")
dest = destdir / path.parent
dest.mkdir(parents=True, exist_ok=True)
download_job = DownloadJob(
source=url,
dest=dest,
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
multifile_job = self._download_queue.multifile_download(
parts=parts,
dest=destdir,
access_token=source.access_token,
)
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
submit_job=False,
on_start=self._download_started_callback,
on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback,
on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback,
)
self._download_cache[multifile_job.id] = install_job
install_job._download_job = multifile_job
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
self._download_queue.submit_multifile_download(multifile_job)
return install_job
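
Pulled together from the hunks above and below, the new bookkeeping is one cache entry per install: the install job is registered under the multifile job's integer id, and every queue callback recovers it with a single dict lookup instead of matching per-part source URLs. A self-contained toy version of that flow (the stub classes are mine):

# Toy model of the id-keyed download cache; the stubs stand in for
# ModelInstallJob and MultiFileDownloadJob.
from typing import Dict

class InstallJobStub:
    status = "waiting"

class MultiFileJobStub:
    def __init__(self, id: int) -> None:
        self.id = id

download_cache: Dict[int, InstallJobStub] = {}

install_job = InstallJobStub()
multifile_job = MultiFileJobStub(id=7)
download_cache[multifile_job.id] = install_job  # one entry per install job

def on_progress(download_job: MultiFileJobStub) -> None:
    job = download_cache[download_job.id]  # O(1) lookup keyed by job id
    job.status = "downloading"

on_progress(multifile_job)
assert install_job.status == "downloading"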
def _stat_size(self, path: Path) -> int:
@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase):
# ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download started: {download_job.source}")
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
install_job = self._download_cache[download_job.id]
install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir:
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
dest_name = partial_path.parts[0]
install_job.local_path = install_job._install_tmpdir / dest_name
if install_job.local_path == install_job._install_tmpdir: # first time
install_job.local_path = download_job.download_path
install_job.total_bytes = download_job.total_bytes
# Update the total bytes count for remote sources.
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
install_job = self._download_cache[download_job.id]
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job)
self._download_queue.cancel_job(download_job)
else:
# update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None:
self._logger.info(f"Model download complete: {download_job.source}")
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache[download_job.source]
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
install_job = self._download_cache.pop(download_job.id)
self._signal_job_downloads_done(install_job)
if install_job._do_install:
self._put_in_queue(install_job)
self._put_in_queue(install_job) # this starts the installation and registration
# Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
install_job = self._download_cache.pop(download_job.id)
assert install_job is not None
assert excp is not None
install_job.set_error(excp)
self._logger.error(
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
self._download_queue.cancel_job(download_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock:
install_job = self._download_cache.pop(download_job.source, None)
install_job = self._download_cache.pop(download_job.id, None)
if not install_job:
return
self._downloads_changed_event.set()
self._logger.warning(f"Model download canceled: {download_job.source}")
# if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored:
install_job.cancel()
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed
self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_download_parts() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus
# ------------------------------------------------------------------------------------------------
@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus:
assert job._download_job is not None
parts: List[Dict[str, str | int]] = [
{
"url": str(x.source),
@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase):
"bytes": x.bytes,
"total_bytes": x.total_bytes,
}
for x in job.download_parts
for x in job._download_job.download_parts
]
assert job.bytes is not None
assert job.total_bytes is not None
@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
Return a metadata fetcher appropriate for the provided url.
This used to be more useful, but the number of supported model
sources has been reduced to HuggingFace alone.
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'")
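
Because the method now returns the fetcher class rather than an instance, callers instantiate it with their own session. A short usage sketch (the repo URL is borrowed from the tests below):

# get_fetcher_from_url returns a class; the caller supplies the session.
from requests.sessions import Session

fetcher_class = ModelInstallService.get_fetcher_from_url("https://huggingface.co/stabilityai/sdxl-turbo")
fetcher = fetcher_class(Session())  # here: HuggingFaceMetadataFetch
metadata = fetcher.from_id("stabilityai/sdxl-turbo")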

View File

@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel):
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
def __hash__(self) -> int:
return hash(str(self))
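
Defining __hash__ lets RemoteModelFile instances be collected into sets and used as dict keys; combined with pydantic's field-based equality, duplicate entries collapse. A small illustration with placeholder values:

# Two RemoteModelFile instances with identical fields hash (and compare)
# equal, so set deduplication works. The URL and path are placeholders.
from pathlib import Path

from pydantic.networks import AnyHttpUrl

from invokeai.backend.model_manager.metadata import RemoteModelFile

a = RemoteModelFile(url=AnyHttpUrl("https://example.com/m.safetensors"), path=Path("m.safetensors"))
b = RemoteModelFile(url=AnyHttpUrl("https://example.com/m.safetensors"), path=Path("m.safetensors"))
assert len({a, b}) == 1  # identical fields -> identical str() -> identical hash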
class ModelMetadataBase(BaseModel):
"""Base class for model metadata information."""

View File

@ -4,79 +4,33 @@ import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator, Optional
from typing import Any, Generator, Optional
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from requests_testadapter import TestAdapter
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore
@pytest.fixture
def session() -> Session:
sess = TestSession()
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess
TestAdapter.__test__ = False
@pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set()
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status)
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
job = queue.download(
@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None:
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None:
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService(
requests_session=session,
requests_session=mm2_session,
)
queue.start()
@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
@pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None:
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start()
cancelled = False
@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
print(f"bytes = {job.bytes}")
events.add(job.status)
queue = DownloadQueueService(
@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path(
tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert job.error_type is not None
assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events
queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
nonlocal cancelled
cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session),
@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
assert "download_cancelled" in [x.event_name for x in events]
queue.stop()
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService(
requests_session=mm2_session,
)
queue.start()
job = queue.multifile_download(
parts=[
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
],
dest=tmp_path,
)
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
queue.stop()
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]:
get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session):
def test_tokens(tmp_path: Path, mm2_session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=session)
queue = DownloadQueueService(requests_session=mm2_session)
queue.start()
# this one has an access token assigned
job1 = queue.download(

View File

@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
# TODO: Test subfolder download
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
job = mm2_installer.download_diffusers_model(source, tmp_path)
mm2_installer.wait_for_installs(timeout=5)
print(job.local_path)
assert job.status == InstallStatus.DOWNLOADS_DONE
assert (tmp_path / "sdxl-turbo").exists()
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
bus = mm2_installer.event_bus
store = mm2_installer.record_store
assert isinstance(bus, EventServiceBase)
assert store is not None
job = mm2_installer.import_model(source)
job_list = mm2_installer.wait_for_installs(timeout=10)
assert len(job_list) == 1
assert job.complete
assert job.config_out
key = job.config_out.key
model_record = store.get_model(key)
assert (mm2_app_config.models_path / model_record.path).exists()
assert model_record.type == ModelType.Main
assert model_record.format == ModelFormat.Diffusers
assert hasattr(bus, "events") # the dummyeventservice has this
assert len(bus.events) >= 3
event_names = {x.event_name for x in bus.events}
assert event_names == {
"model_install_downloading",
"model_install_downloads_done",
"model_install_running",
"model_install_completed",
}
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
@ -327,7 +349,6 @@ def test_other_error_during_install(
assert job.error == "Test error"
# TODO: Fix bug in model install causing jobs to get installed multiple times, then uncomment this test
@pytest.mark.parametrize(
"model_params",
[

View File

@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
},
),
)
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess