refactor model_install to work with refactored download queue

This commit is contained in:
Lincoln Stein 2024-05-13 22:49:15 -04:00
parent 287c679f7b
commit f29c406fed
8 changed files with 226 additions and 220 deletions

View File

@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler] DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler] DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
class DownloadJobBase(BaseModel): class DownloadJobBase(BaseModel):
"""Base of classes to monitor and control downloads.""" """Base of classes to monitor and control downloads."""
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path") dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory") download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download") status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase):
# required variables to be passed in on creation # required variables to be passed in on creation
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.") source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources") access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
# automatically assigned on creation
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
priority: int = Field(default=10, description="Queue priority; lower values are higher priority") priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
# set internally during download process # set internally during download process
@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC):
@abstractmethod @abstractmethod
def multifile_download( def multifile_download(
self, self,
parts: Set[RemoteModelFile], parts: List[RemoteModelFile],
dest: Path, dest: Path,
access_token: Optional[str] = None, access_token: Optional[str] = None,
submit_job: bool = True, submit_job: bool = True,
@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def cancel_job(self, job: DownloadJob) -> None: def cancel_job(self, job: DownloadJobBase) -> None:
"""Cancel the job, clearing partial downloads and putting it into ERROR state.""" """Cancel the job, clearing partial downloads and putting it into ERROR state."""
pass pass
@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
pass pass
@abstractmethod @abstractmethod
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Wait until the indicated download job has reached a terminal state. """Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed, This will block until the indicated install job has completed,

View File

@ -113,9 +113,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
raise ServiceInactiveException( raise ServiceInactiveException(
"The download service is not currently accepting requests. Please call start() to initialize the service." "The download service is not currently accepting requests. Please call start() to initialize the service."
) )
with self._lock: job.id = self._next_id()
job.id = self._next_job_id
self._next_job_id += 1
job.set_callbacks( job.set_callbacks(
on_start=on_start, on_start=on_start,
on_progress=on_progress, on_progress=on_progress,
@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase):
def multifile_download( def multifile_download(
self, self,
parts: Set[RemoteModelFile], parts: List[RemoteModelFile],
dest: Path, dest: Path,
access_token: Optional[str] = None, access_token: Optional[str] = None,
submit_job: bool = True,
on_start: Optional[DownloadEventHandler] = None, on_start: Optional[DownloadEventHandler] = None,
on_progress: Optional[DownloadEventHandler] = None, on_progress: Optional[DownloadEventHandler] = None,
on_complete: Optional[DownloadEventHandler] = None, on_complete: Optional[DownloadEventHandler] = None,
on_cancelled: Optional[DownloadEventHandler] = None, on_cancelled: Optional[DownloadEventHandler] = None,
on_error: Optional[DownloadExceptionHandler] = None, on_error: Optional[DownloadExceptionHandler] = None,
) -> MultiFileDownloadJob: ) -> MultiFileDownloadJob:
mfdj = MultiFileDownloadJob(dest=dest) mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
mfdj.set_callbacks( mfdj.set_callbacks(
on_start=on_start, on_start=on_start,
on_progress=on_progress, on_progress=on_progress,
@ -190,6 +189,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
) )
mfdj.download_parts.add(job) mfdj.download_parts.add(job)
self._download_part2parent[job.source] = mfdj self._download_part2parent[job.source] = mfdj
if submit_job:
self.submit_multifile_download(mfdj) self.submit_multifile_download(mfdj)
return mfdj return mfdj
@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
"""Wait for all jobs to complete.""" """Wait for all jobs to complete."""
self._queue.join() self._queue.join()
def _next_id(self) -> int:
with self._lock:
id = self._next_job_id
self._next_job_id += 1
return id
def list_jobs(self) -> List[DownloadJob]: def list_jobs(self) -> List[DownloadJob]:
"""List all the jobs.""" """List all the jobs."""
return list(self._jobs.values()) return list(self._jobs.values())
@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
except KeyError as excp: except KeyError as excp:
raise UnknownJobIDException("Unrecognized job") from excp raise UnknownJobIDException("Unrecognized job") from excp
def cancel_job(self, job: DownloadJob) -> None: def cancel_job(self, job: DownloadJobBase) -> None:
""" """
Cancel the indicated job. Cancel the indicated job.
@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state: if not job.in_terminal_state:
self.cancel_job(job) self.cancel_job(job)
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob: def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
"""Block until the indicated job has reached terminal state, or when timeout limit reached.""" """Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time() start = time.time()
while not job.in_terminal_state: while not job.in_terminal_state:
@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
if mf_job.waiting: if mf_job.waiting:
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts) mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
mf_job.status = DownloadJobStatus.RUNNING mf_job.status = DownloadJobStatus.RUNNING
assert download_job.download_path is not None
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
mf_job.download_path = (
mf_job.dest / path_relative_to_destdir.parts[0]
) # keep just the first component of the path
self._execute_cb(mf_job, "on_start") self._execute_cb(mf_job, "on_start")
def _mfd_progress(self, download_job: DownloadJob) -> None: def _mfd_progress(self, download_job: DownloadJob) -> None:

View File

@ -6,14 +6,14 @@ import traceback
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Set, Union from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from typing_extensions import Annotated from typing_extensions import Annotated
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import ModelRecordServiceBase from invokeai.app.services.model_records import ModelRecordServiceBase
@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel):
source_metadata: Optional[AnyModelRepoMetadata] = Field( source_metadata: Optional[AnyModelRepoMetadata] = Field(
default=None, description="Metadata provided by the model source" default=None, description="Metadata provided by the model source"
) )
download_parts: Set[DownloadJob] = Field(
default_factory=set, description="Download jobs contributing to this install"
)
error: Optional[str] = Field( error: Optional[str] = Field(
default=None, description="On an error condition, this field will contain the text of the exception" default=None, description="On an error condition, this field will contain the text of the exception"
) )
@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel):
) )
# internal flags and transitory settings # internal flags and transitory settings
_install_tmpdir: Optional[Path] = PrivateAttr(default=None) _install_tmpdir: Optional[Path] = PrivateAttr(default=None)
_do_install: Optional[bool] = PrivateAttr(default=True) _download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
_exception: Optional[Exception] = PrivateAttr(default=None) _exception: Optional[Exception] = PrivateAttr(default=None)
def set_error(self, e: Exception) -> None: def set_error(self, e: Exception) -> None:
@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC):
""" """
@abstractmethod
def download_diffusers_model(
self,
source: HFModelSource,
download_to: Path,
) -> ModelInstallJob:
"""
Download, but do not install, a diffusers model.
:param source: An HFModelSource object containing a repo_id
:param download_to: Path to directory that will contain the downloaded model.
Returns: a ModelInstallJob
"""
@abstractmethod @abstractmethod
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]: def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
"""Return the ModelInstallJob(s) corresponding to the provided source.""" """Return the ModelInstallJob(s) corresponding to the provided source."""

View File

@ -9,7 +9,7 @@ from pathlib import Path
from queue import Empty, Queue from queue import Empty, Queue
from shutil import copyfile, copytree, move, rmtree from shutil import copyfile, copytree, move, rmtree
from tempfile import mkdtemp from tempfile import mkdtemp
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Type, Union
import torch import torch
import yaml import yaml
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
from requests import Session from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._downloads_changed_event = threading.Event() self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event() self._install_completed_event = threading.Event()
self._download_queue = download_queue self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._download_cache: Dict[int, ModelInstallJob] = {}
self._running = False self._running = False
self._session = session self._session = session
self._install_thread: Optional[threading.Thread] = None self._install_thread: Optional[threading.Thread] = None
@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase):
self._install_jobs.append(install_job) self._install_jobs.append(install_job)
return install_job return install_job
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
return self._import_from_hf(source, download_path=download_to)
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102 def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
return self._install_jobs return self._install_jobs
@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase):
def cancel_job(self, job: ModelInstallJob) -> None: def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job.""" """Cancel the indicated job."""
job.cancel() job.cancel()
with self._lock: self._logger.warning(f"Cancelling {job.source}")
self._cancel_download_parts(job) if dj := job._download_job:
self._download_queue.cancel_job(dj)
def prune_jobs(self) -> None: def prune_jobs(self) -> None:
"""Prune all completed and errored jobs.""" """Prune all completed and errored jobs."""
@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase):
legacy_config_path = stanza.get("config") legacy_config_path = stanza.get("config")
if legacy_config_path: if legacy_config_path:
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir. # In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
legacy_config_path: Path = self._app_config.root_path / legacy_config_path legacy_config_path = self._app_config.root_path / legacy_config_path
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path): if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path) legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
config["config_path"] = str(legacy_config_path) config["config_path"] = str(legacy_config_path)
@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase):
job.config_out = self.record_store.get_model(key) job.config_out = self.record_store.get_model(key)
self._signal_job_completed(job) self._signal_job_completed(job)
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None: def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts): download_job = install_job._download_job
job.set_error( if download_job and any(
x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts
):
install_job.set_error(
InvalidModelConfigException( InvalidModelConfigException(
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download." f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
) )
) )
else: else:
job.set_error(excp) install_job.set_error(excp)
self._signal_job_errored(job) self._signal_job_errored(install_job)
# -------------------------------------------------------------------------------------------- # --------------------------------------------------------------------------------------------
# Internal functions that manage the models directory # Internal functions that manage the models directory
@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase):
This is typically only used during testing with a new DB or when using the memory DB, because those are the This is typically only used during testing with a new DB or when using the memory DB, because those are the
only situations in which we may have orphaned models in the models directory. only situations in which we may have orphaned models in the models directory.
""" """
installed_model_paths = { installed_model_paths = {
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models() (self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
} }
@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase):
self, self,
source: HFModelSource, source: HFModelSource,
config: Optional[Dict[str, Any]] = None, config: Optional[Dict[str, Any]] = None,
download_path: Optional[Path] = None,
) -> ModelInstallJob: ) -> ModelInstallJob:
# Add user's cached access token to HuggingFace requests # Add user's cached access token to HuggingFace requests
source.access_token = source.access_token or HfFolder.get_token() source.access_token = source.access_token or HfFolder.get_token()
@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase):
config=config, config=config,
remote_files=remote_files, remote_files=remote_files,
metadata=metadata, metadata=metadata,
download_path=download_path,
) )
def _import_from_url( def _import_from_url(
@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase):
remote_files: List[RemoteModelFile], remote_files: List[RemoteModelFile],
metadata: Optional[AnyModelRepoMetadata], metadata: Optional[AnyModelRepoMetadata],
config: Optional[Dict[str, Any]], config: Optional[Dict[str, Any]],
download_path: Optional[Path] = None, # if defined, download only - don't install!
) -> ModelInstallJob: ) -> ModelInstallJob:
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
# Currently the tmpdir isn't automatically removed at exit because it is
# being held in a daemon thread.
if len(remote_files) == 0: if len(remote_files) == 0:
raise ValueError(f"{source}: No downloadable files found") raise ValueError(f"{source}: No downloadable files found")
destdir = download_path or Path( destdir = Path(
mkdtemp( mkdtemp(
dir=self._app_config.models_path, dir=self._app_config.models_path,
prefix=TMPDIR_PREFIX, prefix=TMPDIR_PREFIX,
@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
bytes=0, bytes=0,
total_bytes=0, total_bytes=0,
) )
# remember the temporary directory for later removal
install_job._install_tmpdir = destdir
# In the event that there is a subfolder specified in the source, # In the event that there is a subfolder specified in the source,
# we need to remove it from the destination path in order to avoid # we need to remove it from the destination path in order to avoid
# creating unwanted subfolders # creating unwanted subfolders
@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase):
# we remember the path up to the top of the destdir so that it may be # we remember the path up to the top of the destdir so that it may be
# removed safely at the end of the install process. # removed safely at the end of the install process.
install_job._install_tmpdir = destdir install_job._install_tmpdir = destdir
install_job._do_install = download_path is None
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
files_string = "file" if len(remote_files) == 1 else "file" parts: List[RemoteModelFile] = []
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
for model_file in remote_files: for model_file in remote_files:
url = model_file.url assert install_job.total_bytes is not None
path = root / model_file.path.relative_to(subfolder) assert model_file.size is not None
self._logger.debug(f"Downloading {url} => {path}")
install_job.total_bytes += model_file.size install_job.total_bytes += model_file.size
assert hasattr(source, "access_token") parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
dest = destdir / path.parent multifile_job = self._download_queue.multifile_download(
dest.mkdir(parents=True, exist_ok=True) parts=parts,
download_job = DownloadJob( dest=destdir,
source=url,
dest=dest,
access_token=source.access_token, access_token=source.access_token,
) submit_job=False,
self._download_cache[download_job.source] = install_job # matches a download job to an install job
install_job.download_parts.add(download_job)
# only start the jobs once install_job.download_parts is fully populated
for download_job in install_job.download_parts:
self._download_queue.submit_download_job(
download_job,
on_start=self._download_started_callback, on_start=self._download_started_callback,
on_progress=self._download_progress_callback, on_progress=self._download_progress_callback,
on_complete=self._download_complete_callback, on_complete=self._download_complete_callback,
on_error=self._download_error_callback, on_error=self._download_error_callback,
on_cancelled=self._download_cancelled_callback, on_cancelled=self._download_cancelled_callback,
) )
self._download_cache[multifile_job.id] = install_job
install_job._download_job = multifile_job
files_string = "file" if len(remote_files) == 1 else "file"
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
self._logger.debug(f"remote_files={remote_files}")
self._download_queue.submit_multifile_download(multifile_job)
return install_job return install_job
def _stat_size(self, path: Path) -> int: def _stat_size(self, path: Path) -> int:
@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase):
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Callbacks are executed by the download queue in a separate thread # Callbacks are executed by the download queue in a separate thread
# ------------------------------------------------------------------ # ------------------------------------------------------------------
def _download_started_callback(self, download_job: DownloadJob) -> None: def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
self._logger.info(f"Model download started: {download_job.source}")
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] install_job = self._download_cache[download_job.id]
install_job.status = InstallStatus.DOWNLOADING install_job.status = InstallStatus.DOWNLOADING
assert download_job.download_path assert download_job.download_path
if install_job.local_path == install_job._install_tmpdir: if install_job.local_path == install_job._install_tmpdir: # first time
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir) install_job.local_path = download_job.download_path
dest_name = partial_path.parts[0] install_job.total_bytes = download_job.total_bytes
install_job.local_path = install_job._install_tmpdir / dest_name
# Update the total bytes count for remote sources. def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
if not install_job.total_bytes:
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
def _download_progress_callback(self, download_job: DownloadJob) -> None:
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] install_job = self._download_cache[download_job.id]
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel() if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
self._cancel_download_parts(install_job) self._download_queue.cancel_job(download_job)
else: else:
# update sizes # update sizes
install_job.bytes = sum(x.bytes for x in install_job.download_parts) install_job.bytes = sum(x.bytes for x in download_job.download_parts)
self._signal_job_downloading(install_job) self._signal_job_downloading(install_job)
def _download_complete_callback(self, download_job: DownloadJob) -> None: def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
self._logger.info(f"Model download complete: {download_job.source}")
with self._lock: with self._lock:
install_job = self._download_cache[download_job.source] install_job = self._download_cache.pop(download_job.id)
# are there any more active jobs left in this task?
if install_job.downloading and all(x.complete for x in install_job.download_parts):
self._signal_job_downloads_done(install_job) self._signal_job_downloads_done(install_job)
if install_job._do_install: self._put_in_queue(install_job) # this starts the installation and registration
self._put_in_queue(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._download_cache.pop(download_job.source, None)
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None: def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
with self._lock: with self._lock:
install_job = self._download_cache.pop(download_job.source, None) install_job = self._download_cache.pop(download_job.id)
assert install_job is not None assert install_job is not None
assert excp is not None assert excp is not None
install_job.set_error(excp) install_job.set_error(excp)
self._logger.error( self._download_queue.cancel_job(download_job)
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
)
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _download_cancelled_callback(self, download_job: DownloadJob) -> None: def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
with self._lock: with self._lock:
install_job = self._download_cache.pop(download_job.source, None) install_job = self._download_cache.pop(download_job.id, None)
if not install_job: if not install_job:
return return
self._downloads_changed_event.set() self._downloads_changed_event.set()
self._logger.warning(f"Model download canceled: {download_job.source}")
# if install job has already registered an error, then do not replace its status with cancelled # if install job has already registered an error, then do not replace its status with cancelled
if not install_job.errored: if not install_job.errored:
install_job.cancel() install_job.cancel()
self._cancel_download_parts(install_job)
# Let other threads know that the number of downloads has changed # Let other threads know that the number of downloads has changed
self._downloads_changed_event.set() self._downloads_changed_event.set()
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
# do not lock here because it gets called within a locked context
for s in install_job.download_parts:
self._download_queue.cancel_job(s)
if all(x.in_terminal_state for x in install_job.download_parts):
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
self._put_in_queue(install_job)
# ------------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------------
# Internal methods that put events on the event bus # Internal methods that put events on the event bus
# ------------------------------------------------------------------------------------------------ # ------------------------------------------------------------------------------------------------
@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase):
def _signal_job_downloading(self, job: ModelInstallJob) -> None: def _signal_job_downloading(self, job: ModelInstallJob) -> None:
if self._event_bus: if self._event_bus:
assert job._download_job is not None
parts: List[Dict[str, str | int]] = [ parts: List[Dict[str, str | int]] = [
{ {
"url": str(x.source), "url": str(x.source),
@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase):
"bytes": x.bytes, "bytes": x.bytes,
"total_bytes": x.total_bytes, "total_bytes": x.total_bytes,
} }
for x in job.download_parts for x in job._download_job.download_parts
] ]
assert job.bytes is not None assert job.bytes is not None
assert job.total_bytes is not None assert job.total_bytes is not None
@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase):
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id) self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
@staticmethod @staticmethod
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase: def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
"""
Return a metadata fetcher appropriate for provided url.
This used to be more useful, but the number of supported model
sources has been reduced to HuggingFace alone.
"""
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()): if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
return HuggingFaceMetadataFetch return HuggingFaceMetadataFetch
raise ValueError(f"Unsupported model source: '{url}'") raise ValueError(f"Unsupported model source: '{url}'")

View File

@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel):
size: Optional[int] = Field(description="The size of this file, in bytes", default=0) size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None) sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
def __hash__(self) -> int:
return hash(str(self))
class ModelMetadataBase(BaseModel): class ModelMetadataBase(BaseModel):
"""Base class for model metadata information.""" """Base class for model metadata information."""

View File

@ -4,79 +4,33 @@ import re
import time import time
from contextlib import contextmanager from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Generator, Optional from typing import Any, Generator, Optional
import pytest import pytest
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter
from invokeai.app.services.config import get_config from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings # Prevent pytest deprecation warnings
TestAdapter.__test__ = False # type: ignore TestAdapter.__test__ = False
@pytest.fixture
def session() -> Session:
    """Return a TestSession with canned adapters mounted for the download tests."""
    mock_session = TestSession()

    for model_id in ("12345", "9999", "54321"):
        # for pause tests, must make content large
        payload = b"I am a safetensors file " + bytearray(model_id, "utf-8") + bytearray(32_000)
        model_headers = {
            "Content-Length": len(payload),
            "Content-Disposition": f'filename="mock{model_id}.safetensors"',
        }
        mock_session.mount(
            f"http://www.civitai.com/models/{model_id}",
            TestAdapter(payload, headers=model_headers),
        )

    # This mount reuses the payload left over from the final loop iteration.
    foo_headers = {
        "Content-Length": len(payload),
        "Content-Disposition": 'filename="foo.safetensors"',
    }
    mock_session.mount(
        "http://www.huggingface.co/foo.txt",
        TestAdapter(payload, headers=foo_headers),
    )

    # here are some malformed URLs to test
    # missing the content length
    missing_headers = {
        "Content-Disposition": 'filename="missing.txt"',
    }
    mock_session.mount(
        "http://www.civitai.com/models/missing",
        TestAdapter(b"Missing content length", headers=missing_headers),
    )

    # not found test
    not_found_adapter = TestAdapter(b"Not found", status=404)
    mock_session.mount("http://www.civitai.com/models/broken", not_found_adapter)

    return mock_session
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_basic_queue_download(tmp_path: Path, session: Session) -> None: def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
events = set() events = set()
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None: def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
events.add(job.status) events.add(job.status)
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=mm2_session,
) )
queue.start() queue.start()
job = queue.download( job = queue.download(
@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
queue.join() queue.join()
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.download_path == tmp_path / "mock12345.safetensors"
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist" assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED} assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_errors(tmp_path: Path, session: Session) -> None: def test_errors(tmp_path: Path, mm2_session: Session) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=mm2_session,
) )
queue.start() queue.start()
@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_event_bus(tmp_path: Path, session: Session) -> None: def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService() event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start() queue.start()
queue.download( queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"), source=AnyHttpUrl("http://www.civitai.com/models/12345"),
@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None: def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
queue = DownloadQueueService( queue = DownloadQueueService(
requests_session=session, requests_session=mm2_session,
) )
queue.start() queue.start()
@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_cancel(tmp_path: Path, session: Session) -> None: def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
event_bus = TestEventService() event_bus = TestEventService()
queue = DownloadQueueService(requests_session=session, event_bus=event_bus) queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
queue.start() queue.start()
cancelled = False cancelled = False
@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
nonlocal cancelled nonlocal cancelled
cancelled = True cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
job = queue.download( job = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"), source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path, dest=tmp_path,
@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None: def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session) fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo") metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set() events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
print(f"bytes = {job.bytes}")
events.add(job.status) events.add(job.status)
queue = DownloadQueueService( queue = DownloadQueueService(
@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed" assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
assert job.bytes > 0, "expected download bytes to be positive" assert job.bytes > 0, "expected download bytes to be positive"
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes" assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
assert job.download_path == tmp_path / "sdxl-turbo"
assert Path( assert Path(
tmp_path, "sdxl-turbo/model_index.json" tmp_path, "sdxl-turbo/model_index.json"
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist" ).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None: def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
fetcher = HuggingFaceMetadataFetch(mm2_session) fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo") metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
events = set() events = set()
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None: def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
queue.join() queue.join()
assert job.status == DownloadJobStatus("error"), "expected job status to be errored" assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
assert job.error_type is not None
assert "HTTPError(NOT FOUND)" in job.error_type assert "HTTPError(NOT FOUND)" in job.error_type
assert DownloadJobStatus.ERROR in events assert DownloadJobStatus.ERROR in events
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=10, method="thread") @pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None: def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
event_bus = TestEventService() event_bus = TestEventService()
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus) queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
nonlocal cancelled nonlocal cancelled
cancelled = True cancelled = True
def handler(signum, frame):
raise TimeoutError("Join took too long to return")
fetcher = HuggingFaceMetadataFetch(mm2_session) fetcher = HuggingFaceMetadataFetch(mm2_session)
metadata = fetcher.from_id("stabilityai/sdxl-turbo") metadata = fetcher.from_id("stabilityai/sdxl-turbo")
assert isinstance(metadata, ModelMetadataWithFiles)
job = queue.multifile_download( job = queue.multifile_download(
parts=metadata.download_urls(session=mm2_session), parts=metadata.download_urls(session=mm2_session),
@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
assert "download_cancelled" in [x.event_name for x in events] assert "download_cancelled" in [x.event_name for x in events]
queue.stop() queue.stop()
@pytest.mark.timeout(timeout=10, method="thread")
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
    """Check that a multifile download consisting of a single part completes.

    The job is expected to finish with status "completed", report equal
    positive byte counts, and set download_path to the single downloaded
    file inside tmp_path.
    """
    # The timeout marker matches the other queue tests in this file:
    # queue.join() blocks, so a stuck download would otherwise hang the run.
    queue = DownloadQueueService(
        requests_session=mm2_session,
    )
    queue.start()
    job = queue.multifile_download(
        parts=[
            RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
        ],
        dest=tmp_path,
    )
    assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJob"
    queue.join()

    assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
    assert job.bytes > 0, "expected download bytes to be positive"
    assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
    assert job.download_path == tmp_path / "mock12345.safetensors"
    assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
    queue.stop()
@contextmanager @contextmanager
def clear_config() -> Generator[None, None, None]: def clear_config() -> Generator[None, None, None]:
try: try:
@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]:
get_config.cache_clear() get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session): def test_tokens(tmp_path: Path, mm2_session: Session):
with clear_config(): with clear_config():
config = get_config() config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")] config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=session) queue = DownloadQueueService(requests_session=mm2_session)
queue.start() queue.start()
# this one has an access token assigned # this one has an access token assigned
job1 = queue.download( job1 = queue.download(

View File

@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
@pytest.mark.timeout(timeout=20, method="thread")
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
    """Import a model by HuggingFace repo id and check its record and emitted events."""
    hf_source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)

    event_bus = mm2_installer.event_bus
    record_store = mm2_installer.record_store
    assert isinstance(event_bus, EventServiceBase)
    assert record_store is not None

    install_job = mm2_installer.import_model(hf_source)
    finished_jobs = mm2_installer.wait_for_installs(timeout=10)
    assert len(finished_jobs) == 1
    assert install_job.complete
    assert install_job.config_out

    # Look the installed model up by the key recorded on the finished job.
    record_key = install_job.config_out.key
    record = record_store.get_model(record_key)
    installed_path = mm2_app_config.models_path / record.path
    assert installed_path.exists()
    assert record.type == ModelType.Main
    assert record.format == ModelFormat.Diffusers

    assert hasattr(event_bus, "events")  # the dummyeventservice has this
    assert len(event_bus.events) >= 3
    expected_names = {
        "model_install_downloading",
        "model_install_downloads_done",
        "model_install_running",
        "model_install_completed",
    }
    seen_names = {event.event_name for event in event_bus.events}
    assert seen_names == expected_names
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None: def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
@ -327,7 +349,6 @@ def test_other_error_during_install(
assert job.error == "Test error" assert job.error == "Test error"
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model_params", "model_params",
[ [

View File

@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
}, },
), ),
) )
for i in ["12345", "9999", "54321"]:
content = (
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
) # for pause tests, must make content large
sess.mount(
f"http://www.civitai.com/models/{i}",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": f'filename="mock{i}.safetensors"',
},
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
"http://www.civitai.com/models/missing",
TestAdapter(
b"Missing content length",
headers={
"Content-Disposition": 'filename="missing.txt"',
},
),
)
# not found test
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
return sess return sess