diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index 3e415091c7..4880ab98b8 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] + class DownloadJobBase(BaseModel): """Base of classes to monitor and control downloads.""" + # automatically assigned on creation + id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel + dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") @@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase): # required variables to be passed in on creation source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") - # automatically assigned on creation - id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel priority: int = Field(default=10, description="Queue priority; lower values are higher priority") # set internally during download process @@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC): @abstractmethod def multifile_download( self, - parts: Set[RemoteModelFile], + parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, submit_job: bool = True, @@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """Cancel the job, clearing partial downloads and putting it into ERROR state.""" pass @@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC): pass @abstractmethod - def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Wait until the indicated download job has reached a terminal state. This will block until the indicated install job has completed, diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 3f55e9a254..4555477004 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -113,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase): raise ServiceInactiveException( "The download service is not currently accepting requests. Please call start() to initialize the service." ) - with self._lock: - job.id = self._next_job_id - self._next_job_id += 1 - job.set_callbacks( - on_start=on_start, - on_progress=on_progress, - on_complete=on_complete, - on_cancelled=on_cancelled, - on_error=on_error, - ) - self._jobs[job.id] = job - self._queue.put(job) + job.id = self._next_id() + job.set_callbacks( + on_start=on_start, + on_progress=on_progress, + on_complete=on_complete, + on_cancelled=on_cancelled, + on_error=on_error, + ) + self._jobs[job.id] = job + self._queue.put(job) def download( self, @@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase): def multifile_download( self, - parts: Set[RemoteModelFile], + parts: List[RemoteModelFile], dest: Path, access_token: Optional[str] = None, + submit_job: bool = True, on_start: Optional[DownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None, on_cancelled: Optional[DownloadEventHandler] = None, on_error: Optional[DownloadExceptionHandler] = None, ) -> MultiFileDownloadJob: - mfdj = MultiFileDownloadJob(dest=dest) + mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id()) mfdj.set_callbacks( on_start=on_start, on_progress=on_progress, @@ -190,7 +189,8 @@ class DownloadQueueService(DownloadQueueServiceBase): ) mfdj.download_parts.add(job) self._download_part2parent[job.source] = mfdj - self.submit_multifile_download(mfdj) + if submit_job: + self.submit_multifile_download(mfdj) return mfdj def submit_multifile_download(self, job: MultiFileDownloadJob) -> None: @@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase): """Wait for all jobs to complete.""" self._queue.join() + def _next_id(self) -> int: + with self._lock: + id = self._next_job_id + self._next_job_id += 1 + return id + def list_jobs(self) -> List[DownloadJob]: """List all the jobs.""" return list(self._jobs.values()) @@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase): except KeyError as excp: raise UnknownJobIDException("Unrecognized job") from excp - def cancel_job(self, job: DownloadJob) -> None: + def cancel_job(self, job: DownloadJobBase) -> None: """ Cancel the indicated job. @@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase): if not job.in_terminal_state: self.cancel_job(job) - def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: + def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase: """Block until the indicated job has reached terminal state, or when timeout limit reached.""" start = time.time() while not job.in_terminal_state: @@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase): if mf_job.waiting: mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) mf_job.status = DownloadJobStatus.RUNNING + assert download_job.download_path is not None + path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest) + mf_job.download_path = ( + mf_job.dest / path_relative_to_destdir.parts[0] + ) # keep just the first component of the path self._execute_cb(mf_job, "on_start") def _mfd_progress(self, download_job: DownloadJob) -> None: diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index ccb8e3772e..68cf9591e0 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -6,14 +6,14 @@ import traceback from abc import ABC, abstractmethod from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Set, Union +from typing import Any, Dict, List, Literal, Optional, Union from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic.networks import AnyHttpUrl from typing_extensions import Annotated from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import ModelRecordServiceBase @@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel): source_metadata: Optional[AnyModelRepoMetadata] = Field( default=None, description="Metadata provided by the model source" ) - download_parts: Set[DownloadJob] = Field( - default_factory=set, description="Download jobs contributing to this install" - ) error: Optional[str] = Field( default=None, description="On an error condition, this field will contain the text of the exception" ) @@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel): ) # internal flags and transitory settings _install_tmpdir: Optional[Path] = PrivateAttr(default=None) - _do_install: Optional[bool] = PrivateAttr(default=True) + _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None) def set_error(self, e: Exception) -> None: @@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC): """ - @abstractmethod - def download_diffusers_model( - self, - source: HFModelSource, - download_to: Path, - ) -> ModelInstallJob: - """ - Download, but do not install, a diffusers model. - - :param source: An HFModelSource object containing a repo_id - :param download_to: Path to directory that will contain the downloaded model. - - Returns: a ModelInstallJob - """ - @abstractmethod def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: """Return the ModelInstallJob(s) corresponding to the provided source.""" diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index f59c7b9f85..2ad321c260 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -9,7 +9,7 @@ from pathlib import Path from queue import Empty, Queue from shutil import copyfile, copytree, move, rmtree from tempfile import mkdtemp -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Type, Union import torch import yaml @@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress +from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase @@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase): self._downloads_changed_event = threading.Event() self._install_completed_event = threading.Event() self._download_queue = download_queue - self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} + self._download_cache: Dict[int, ModelInstallJob] = {} self._running = False self._session = session self._install_thread: Optional[threading.Thread] = None @@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase): self._install_jobs.append(install_job) return install_job - def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob: - return self._import_from_hf(source, download_path=download_to) - def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 return self._install_jobs @@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase): def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" job.cancel() - with self._lock: - self._cancel_download_parts(job) + self._logger.warning(f"Cancelling {job.source}") + if dj := job._download_job: + self._download_queue.cancel_job(dj) def prune_jobs(self) -> None: """Prune all completed and errored jobs.""" @@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase): legacy_config_path = stanza.get("config") if legacy_config_path: # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. - legacy_config_path: Path = self._app_config.root_path / legacy_config_path + legacy_config_path = self._app_config.root_path / legacy_config_path if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) config["config_path"] = str(legacy_config_path) @@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase): job.config_out = self.record_store.get_model(key) self._signal_job_completed(job) - def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: - if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): - job.set_error( + def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None: + download_job = install_job._download_job + if download_job and any( + x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts + ): + install_job.set_error( InvalidModelConfigException( - f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." + f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." ) ) else: - job.set_error(excp) - self._signal_job_errored(job) + install_job.set_error(excp) + self._signal_job_errored(install_job) # -------------------------------------------------------------------------------------------- # Internal functions that manage the models directory @@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase): This is typically only used during testing with a new DB or when using the memory DB, because those are the only situations in which we may have orphaned models in the models directory. """ - installed_model_paths = { (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() } @@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase): self, source: HFModelSource, config: Optional[Dict[str, Any]] = None, - download_path: Optional[Path] = None, ) -> ModelInstallJob: # Add user's cached access token to HuggingFace requests source.access_token = source.access_token or HfFolder.get_token() @@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase): config=config, remote_files=remote_files, metadata=metadata, - download_path=download_path, ) def _import_from_url( @@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase): remote_files: List[RemoteModelFile], metadata: Optional[AnyModelRepoMetadata], config: Optional[Dict[str, Any]], - download_path: Optional[Path] = None, # if defined, download only - don't install! ) -> ModelInstallJob: - # TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up. - # Currently the tmpdir isn't automatically removed at exit because it is - # being held in a daemon thread. if len(remote_files) == 0: raise ValueError(f"{source}: No downloadable files found") - destdir = download_path or Path( + destdir = Path( mkdtemp( dir=self._app_config.models_path, prefix=TMPDIR_PREFIX, @@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase): bytes=0, total_bytes=0, ) + # remember the temporary directory for later removal + install_job._install_tmpdir = destdir + # In the event that there is a subfolder specified in the source, # we need to remove it from the destination path in order to avoid # creating unwanted subfolders @@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase): # we remember the path up to the top of the destdir so that it may be # removed safely at the end of the install process. install_job._install_tmpdir = destdir - install_job._do_install = download_path is None - assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below + + parts: List[RemoteModelFile] = [] + for model_file in remote_files: + assert install_job.total_bytes is not None + assert model_file.size is not None + install_job.total_bytes += model_file.size + parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder))) + multifile_job = self._download_queue.multifile_download( + parts=parts, + dest=destdir, + access_token=source.access_token, + submit_job=False, + on_start=self._download_started_callback, + on_progress=self._download_progress_callback, + on_complete=self._download_complete_callback, + on_error=self._download_error_callback, + on_cancelled=self._download_cancelled_callback, + ) + self._download_cache[multifile_job.id] = install_job + install_job._download_job = multifile_job files_string = "file" if len(remote_files) == 1 else "file" self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})") self._logger.debug(f"remote_files={remote_files}") - for model_file in remote_files: - url = model_file.url - path = root / model_file.path.relative_to(subfolder) - self._logger.debug(f"Downloading {url} => {path}") - install_job.total_bytes += model_file.size - assert hasattr(source, "access_token") - dest = destdir / path.parent - dest.mkdir(parents=True, exist_ok=True) - download_job = DownloadJob( - source=url, - dest=dest, - access_token=source.access_token, - ) - self._download_cache[download_job.source] = install_job # matches a download job to an install job - install_job.download_parts.add(download_job) - - # only start the jobs once install_job.download_parts is fully populated - for download_job in install_job.download_parts: - self._download_queue.submit_download_job( - download_job, - on_start=self._download_started_callback, - on_progress=self._download_progress_callback, - on_complete=self._download_complete_callback, - on_error=self._download_error_callback, - on_cancelled=self._download_cancelled_callback, - ) - + self._download_queue.submit_multifile_download(multifile_job) return install_job def _stat_size(self, path: Path) -> int: @@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase): # ------------------------------------------------------------------ # Callbacks are executed by the download queue in a separate thread # ------------------------------------------------------------------ - def _download_started_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download started: {download_job.source}") + def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] + install_job = self._download_cache[download_job.id] install_job.status = InstallStatus.DOWNLOADING assert download_job.download_path - if install_job.local_path == install_job._install_tmpdir: - partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) - dest_name = partial_path.parts[0] - install_job.local_path = install_job._install_tmpdir / dest_name + if install_job.local_path == install_job._install_tmpdir: # first time + install_job.local_path = download_job.download_path + install_job.total_bytes = download_job.total_bytes - # Update the total bytes count for remote sources. - if not install_job.total_bytes: - install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts) - - def _download_progress_callback(self, download_job: DownloadJob) -> None: + def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] + install_job = self._download_cache[download_job.id] if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() - self._cancel_download_parts(install_job) + self._download_queue.cancel_job(download_job) else: # update sizes - install_job.bytes = sum(x.bytes for x in install_job.download_parts) + install_job.bytes = sum(x.bytes for x in download_job.download_parts) self._signal_job_downloading(install_job) - def _download_complete_callback(self, download_job: DownloadJob) -> None: - self._logger.info(f"Model download complete: {download_job.source}") + def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache[download_job.source] - - # are there any more active jobs left in this task? - if install_job.downloading and all(x.complete for x in install_job.download_parts): - self._signal_job_downloads_done(install_job) - if install_job._do_install: - self._put_in_queue(install_job) + install_job = self._download_cache.pop(download_job.id) + self._signal_job_downloads_done(install_job) + self._put_in_queue(install_job) # this starts the installation and registration # Let other threads know that the number of downloads has changed - self._download_cache.pop(download_job.source, None) self._downloads_changed_event.set() - def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: + def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) + install_job = self._download_cache.pop(download_job.id) assert install_job is not None assert excp is not None install_job.set_error(excp) - self._logger.error( - f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}" - ) - self._cancel_download_parts(install_job) + self._download_queue.cancel_job(download_job) # Let other threads know that the number of downloads has changed self._downloads_changed_event.set() - def _download_cancelled_callback(self, download_job: DownloadJob) -> None: + def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None: with self._lock: - install_job = self._download_cache.pop(download_job.source, None) + install_job = self._download_cache.pop(download_job.id, None) if not install_job: return self._downloads_changed_event.set() - self._logger.warning(f"Model download canceled: {download_job.source}") # if install job has already registered an error, then do not replace its status with cancelled if not install_job.errored: install_job.cancel() - self._cancel_download_parts(install_job) # Let other threads know that the number of downloads has changed self._downloads_changed_event.set() - def _cancel_download_parts(self, install_job: ModelInstallJob) -> None: - # on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks - # do not lock here because it gets called within a locked context - for s in install_job.download_parts: - self._download_queue.cancel_job(s) - - if all(x.in_terminal_state for x in install_job.download_parts): - # When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources - self._put_in_queue(install_job) - # ------------------------------------------------------------------------------------------------ # Internal methods that put events on the event bus # ------------------------------------------------------------------------------------------------ @@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase): def _signal_job_downloading(self, job: ModelInstallJob) -> None: if self._event_bus: + assert job._download_job is not None parts: List[Dict[str, str | int]] = [ { "url": str(x.source), @@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase): "bytes": x.bytes, "total_bytes": x.total_bytes, } - for x in job.download_parts + for x in job._download_job.download_parts ] assert job.bytes is not None assert job.total_bytes is not None @@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase): self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) @staticmethod - def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: + def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]: + """ + Return a metadata fetcher appropriate for provided url. + + This used to be more useful, but the number of supported model + sources has been reduced to HuggingFace alone. + """ if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): return HuggingFaceMetadataFetch raise ValueError(f"Unsupported model source: '{url}'") diff --git a/invokeai/backend/model_manager/metadata/metadata_base.py b/invokeai/backend/model_manager/metadata/metadata_base.py index 4abf020538..f9f5335d17 100644 --- a/invokeai/backend/model_manager/metadata/metadata_base.py +++ b/invokeai/backend/model_manager/metadata/metadata_base.py @@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel): size: Optional[int] = Field(description="The size of this file, in bytes", default=0) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) + def __hash__(self) -> int: + return hash(str(self)) + class ModelMetadataBase(BaseModel): """Base class for model metadata information.""" diff --git a/tests/app/services/download/test_download_queue.py b/tests/app/services/download/test_download_queue.py index 393cd54a03..564d9c30a0 100644 --- a/tests/app/services/download/test_download_queue.py +++ b/tests/app/services/download/test_download_queue.py @@ -4,79 +4,33 @@ import re import time from contextlib import contextmanager from pathlib import Path -from typing import Generator, Optional +from typing import Any, Generator, Optional import pytest from pydantic.networks import AnyHttpUrl from requests.sessions import Session -from requests_testadapter import TestAdapter, TestSession +from requests_testadapter import TestAdapter from invokeai.app.services.config import get_config from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob -from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile +from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.test_nodes import TestEventService # Prevent pytest deprecation warnings -TestAdapter.__test__ = False # type: ignore - - -@pytest.fixture -def session() -> Session: - sess = TestSession() - for i in ["12345", "9999", "54321"]: - content = ( - b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) - ) # for pause tests, must make content large - sess.mount( - f"http://www.civitai.com/models/{i}", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": f'filename="mock{i}.safetensors"', - }, - ), - ) - - sess.mount( - "http://www.huggingface.co/foo.txt", - TestAdapter( - content, - headers={ - "Content-Length": len(content), - "Content-Disposition": 'filename="foo.safetensors"', - }, - ), - ) - - # here are some malformed URLs to test - # missing the content length - sess.mount( - "http://www.civitai.com/models/missing", - TestAdapter( - b"Missing content length", - headers={ - "Content-Disposition": 'filename="missing.txt"', - }, - ), - ) - # not found test - sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) - - return sess +TestAdapter.__test__ = False @pytest.mark.timeout(timeout=10, method="thread") -def test_basic_queue_download(tmp_path: Path, session: Session) -> None: +def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None: events = set() def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: events.add(job.status) queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() job = queue.download( @@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None: queue.join() assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.download_path == tmp_path / "mock12345.safetensors" assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} @@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_errors(tmp_path: Path, session: Session) -> None: +def test_errors(tmp_path: Path, mm2_session: Session) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_event_bus(tmp_path: Path, session: Session) -> None: +def test_event_bus(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), @@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: +def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None: queue = DownloadQueueService( - requests_session=session, + requests_session=mm2_session, ) queue.start() @@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: @pytest.mark.timeout(timeout=10, method="thread") -def test_cancel(tmp_path: Path, session: Session) -> None: +def test_cancel(tmp_path: Path, mm2_session: Session) -> None: event_bus = TestEventService() - queue = DownloadQueueService(requests_session=session, event_bus=event_bus) + queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue.start() cancelled = False @@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None: nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - job = queue.download( source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=tmp_path, @@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None: assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" queue.stop() + @pytest.mark.timeout(timeout=10, method="thread") def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) events = set() def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: - print(f"bytes = {job.bytes}") events.add(job.status) queue = DownloadQueueService( @@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" assert job.bytes > 0, "expected download bytes to be positive" assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "sdxl-turbo" assert Path( tmp_path, "sdxl-turbo/model_index.json" ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" @@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) events = set() def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: @@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: queue.join() assert job.status == DownloadJobStatus("error"), "expected job status to be errored" + assert job.error_type is not None assert "HTTPError(NOT FOUND)" in job.error_type assert DownloadJobStatus.ERROR in events queue.stop() @pytest.mark.timeout(timeout=10, method="thread") -def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: +def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None: event_bus = TestEventService() queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) @@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> nonlocal cancelled cancelled = True - def handler(signum, frame): - raise TimeoutError("Join took too long to return") - fetcher = HuggingFaceMetadataFetch(mm2_session) metadata = fetcher.from_id("stabilityai/sdxl-turbo") + assert isinstance(metadata, ModelMetadataWithFiles) job = queue.multifile_download( parts=metadata.download_urls(session=mm2_session), @@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> assert "download_cancelled" in [x.event_name for x in events] queue.stop() + +def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None: + queue = DownloadQueueService( + requests_session=mm2_session, + ) + queue.start() + job = queue.multifile_download( + parts=[ + RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors")) + ], + dest=tmp_path, + ) + assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase" + queue.join() + + assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" + assert job.bytes > 0, "expected download bytes to be positive" + assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" + assert job.download_path == tmp_path / "mock12345.safetensors" + assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" + queue.stop() + + @contextmanager def clear_config() -> Generator[None, None, None]: try: @@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]: get_config.cache_clear() -def test_tokens(tmp_path: Path, session: Session): +def test_tokens(tmp_path: Path, mm2_session: Session): with clear_config(): config = get_config() config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] - queue = DownloadQueueService(requests_session=session) + queue = DownloadQueueService(requests_session=mm2_session) queue.start() # this one has an access token assigned job1 = queue.download( diff --git a/tests/app/services/model_install/test_model_install.py b/tests/app/services/model_install/test_model_install.py index ba84455240..31d09d1029 100644 --- a/tests/app/services/model_install/test_model_install.py +++ b/tests/app/services/model_install/test_model_install.py @@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con @pytest.mark.timeout(timeout=20, method="thread") -def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None: +def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: + # TODO: Test subfolder download source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default) - job = mm2_installer.download_diffusers_model(source, tmp_path) - mm2_installer.wait_for_installs(timeout=5) - print(job.local_path) - assert job.status == InstallStatus.DOWNLOADS_DONE - assert (tmp_path / "sdxl-turbo").exists() - assert (tmp_path / "sdxl-turbo" / "model_index.json").exists() + + bus = mm2_installer.event_bus + store = mm2_installer.record_store + assert isinstance(bus, EventServiceBase) + assert store is not None + + job = mm2_installer.import_model(source) + job_list = mm2_installer.wait_for_installs(timeout=10) + assert len(job_list) == 1 + assert job.complete + assert job.config_out + + key = job.config_out.key + model_record = store.get_model(key) + assert (mm2_app_config.models_path / model_record.path).exists() + assert model_record.type == ModelType.Main + assert model_record.format == ModelFormat.Diffusers + + assert hasattr(bus, "events") # the dummyeventservice has this + assert len(bus.events) >= 3 + event_names = {x.event_name for x in bus.events} + assert event_names == { + "model_install_downloading", + "model_install_downloads_done", + "model_install_running", + "model_install_completed", + } def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: @@ -327,7 +349,6 @@ def test_other_error_during_install( assert job.error == "Test error" -# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test @pytest.mark.parametrize( "model_params", [ diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py index 980f6ea17b..0301101a19 100644 --- a/tests/backend/model_manager/model_manager_fixtures.py +++ b/tests/backend/model_manager/model_manager_fixtures.py @@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session: }, ), ) + + for i in ["12345", "9999", "54321"]: + content = ( + b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000) + ) # for pause tests, must make content large + sess.mount( + f"http://www.civitai.com/models/{i}", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": f'filename="mock{i}.safetensors"', + }, + ), + ) + + sess.mount( + "http://www.huggingface.co/foo.txt", + TestAdapter( + content, + headers={ + "Content-Length": len(content), + "Content-Disposition": 'filename="foo.safetensors"', + }, + ), + ) + + # here are some malformed URLs to test + # missing the content length + sess.mount( + "http://www.civitai.com/models/missing", + TestAdapter( + b"Missing content length", + headers={ + "Content-Disposition": 'filename="missing.txt"', + }, + ), + ) + # not found test + sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + return sess