mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor model_install to work with refactored download queue
This commit is contained in:
parent
287c679f7b
commit
f29c406fed
@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E
|
||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||
|
||||
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Base of classes to monitor and control downloads."""
|
||||
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
|
||||
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||
@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase):
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: Set[RemoteModelFile],
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
|
@ -113,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
job.id = self._next_id()
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
|
||||
def download(
|
||||
self,
|
||||
@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: Set[RemoteModelFile],
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
mfdj = MultiFileDownloadJob(dest=dest)
|
||||
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
|
||||
mfdj.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
@ -190,7 +189,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
self.submit_multifile_download(mfdj)
|
||||
if submit_job:
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if mf_job.waiting:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.status = DownloadJobStatus.RUNNING
|
||||
assert download_job.download_path is not None
|
||||
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
|
||||
mf_job.download_path = (
|
||||
mf_job.dest / path_relative_to_destdir.parts[0]
|
||||
) # keep just the first component of the path
|
||||
self._execute_cb(mf_job, "on_start")
|
||||
|
||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||
|
@ -6,14 +6,14 @@ import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel):
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_do_install: Optional[bool] = PrivateAttr(default=True)
|
||||
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_diffusers_model(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
download_to: Path,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Download, but do not install, a diffusers model.
|
||||
|
||||
:param source: An HFModelSource object containing a repo_id
|
||||
:param download_to: Path to directory that will contain the downloaded model.
|
||||
|
||||
Returns: a ModelInstallJob
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._install_thread: Optional[threading.Thread] = None
|
||||
@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
|
||||
return self._import_from_hf(source, download_path=download_to)
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
job.cancel()
|
||||
with self._lock:
|
||||
self._cancel_download_parts(job)
|
||||
self._logger.warning(f"Cancelling {job.source}")
|
||||
if dj := job._download_job:
|
||||
self._download_queue.cancel_job(dj)
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||
job.set_error(
|
||||
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||
download_job = install_job._download_job
|
||||
if download_job and any(
|
||||
x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts
|
||||
):
|
||||
install_job.set_error(
|
||||
InvalidModelConfigException(
|
||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
)
|
||||
)
|
||||
else:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
install_job.set_error(excp)
|
||||
self._signal_job_errored(install_job)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
||||
only situations in which we may have orphaned models in the models directory.
|
||||
"""
|
||||
|
||||
installed_model_paths = {
|
||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||
}
|
||||
@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
download_path: Optional[Path] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config=config,
|
||||
remote_files=remote_files,
|
||||
metadata=metadata,
|
||||
download_path=download_path,
|
||||
)
|
||||
|
||||
def _import_from_url(
|
||||
@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
download_path: Optional[Path] = None, # if defined, download only - don't install!
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
destdir = download_path or Path(
|
||||
destdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# remember the temporary directory for later removal
|
||||
install_job._install_tmpdir = destdir
|
||||
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# we remember the path up to the top of the destdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job._do_install = download_path is None
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert install_job.total_bytes is not None
|
||||
assert model_file.size is not None
|
||||
install_job.total_bytes += model_file.size
|
||||
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
|
||||
multifile_job = self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=destdir,
|
||||
access_token=source.access_token,
|
||||
submit_job=False,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
self._download_cache[multifile_job.id] = install_job
|
||||
install_job._download_job = multifile_job
|
||||
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.debug(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = destdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
dest=dest,
|
||||
access_token=source.access_token,
|
||||
)
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
# only start the jobs once install_job.download_parts is fully populated
|
||||
for download_job in install_job.download_parts:
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
self._download_queue.submit_multifile_download(multifile_job)
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download started: {download_job.source}")
|
||||
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job = self._download_cache[download_job.id]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir:
|
||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
||||
dest_name = partial_path.parts[0]
|
||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
|
||||
# Update the total bytes count for remote sources.
|
||||
if not install_job.total_bytes:
|
||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
||||
|
||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job = self._download_cache[download_job.id]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download complete: {download_job.source}")
|
||||
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
if install_job._do_install:
|
||||
self._put_in_queue(install_job)
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
self._cancel_download_parts(install_job)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
install_job = self._download_cache.pop(download_job.id, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"Model download canceled: {download_job.source}")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
||||
# do not lock here because it gets called within a locked context
|
||||
for s in install_job.download_parts:
|
||||
self._download_queue.cancel_job(s)
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._put_in_queue(install_job)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
assert job._download_job is not None
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
for x in job._download_job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||
"""
|
||||
Return a metadata fetcher appropriate for provided url.
|
||||
|
||||
This used to be more useful, but the number of supported model
|
||||
sources has been reduced to HuggingFace alone.
|
||||
"""
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel):
|
||||
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
@ -4,79 +4,33 @@ import re
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator, Optional
|
||||
from typing import Any, Generator, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
from requests_testadapter import TestAdapter
|
||||
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
|
||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
TestAdapter.__test__ = False # type: ignore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session() -> Session:
|
||||
sess = TestSession()
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
sess.mount(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": 'filename="foo.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
TestAdapter.__test__ = False
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.download(
|
||||
@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
|
||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||
@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=session,
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
|
||||
@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
queue.start()
|
||||
|
||||
cancelled = False
|
||||
@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
job = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
print(f"bytes = {job.bytes}")
|
||||
events.add(job.status)
|
||||
|
||||
queue = DownloadQueueService(
|
||||
@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||
assert Path(
|
||||
tmp_path, "sdxl-turbo/model_index.json"
|
||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||
@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
events = set()
|
||||
|
||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
||||
assert job.error_type is not None
|
||||
assert "HTTPError(NOT FOUND)" in job.error_type
|
||||
assert DownloadJobStatus.ERROR in events
|
||||
queue.stop()
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=10, method="thread")
|
||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
|
||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
|
||||
event_bus = TestEventService()
|
||||
|
||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||
@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
|
||||
nonlocal cancelled
|
||||
cancelled = True
|
||||
|
||||
def handler(signum, frame):
|
||||
raise TimeoutError("Join took too long to return")
|
||||
|
||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||
|
||||
job = queue.multifile_download(
|
||||
parts=metadata.download_urls(session=mm2_session),
|
||||
@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
|
||||
assert "download_cancelled" in [x.event_name for x in events]
|
||||
queue.stop()
|
||||
|
||||
|
||||
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||
queue = DownloadQueueService(
|
||||
requests_session=mm2_session,
|
||||
)
|
||||
queue.start()
|
||||
job = queue.multifile_download(
|
||||
parts=[
|
||||
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
|
||||
],
|
||||
dest=tmp_path,
|
||||
)
|
||||
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||
queue.join()
|
||||
|
||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||
assert job.bytes > 0, "expected download bytes to be positive"
|
||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_tokens(tmp_path: Path, session: Session):
|
||||
def test_tokens(tmp_path: Path, mm2_session: Session):
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||
queue = DownloadQueueService(requests_session=session)
|
||||
queue = DownloadQueueService(requests_session=mm2_session)
|
||||
queue.start()
|
||||
# this one has an access token assigned
|
||||
job1 = queue.download(
|
||||
|
@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
||||
|
||||
|
||||
@pytest.mark.timeout(timeout=20, method="thread")
|
||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
|
||||
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
# TODO: Test subfolder download
|
||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||
job = mm2_installer.download_diffusers_model(source, tmp_path)
|
||||
mm2_installer.wait_for_installs(timeout=5)
|
||||
print(job.local_path)
|
||||
assert job.status == InstallStatus.DOWNLOADS_DONE
|
||||
assert (tmp_path / "sdxl-turbo").exists()
|
||||
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
|
||||
|
||||
bus = mm2_installer.event_bus
|
||||
store = mm2_installer.record_store
|
||||
assert isinstance(bus, EventServiceBase)
|
||||
assert store is not None
|
||||
|
||||
job = mm2_installer.import_model(source)
|
||||
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||
assert len(job_list) == 1
|
||||
assert job.complete
|
||||
assert job.config_out
|
||||
|
||||
key = job.config_out.key
|
||||
model_record = store.get_model(key)
|
||||
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||
assert model_record.type == ModelType.Main
|
||||
assert model_record.format == ModelFormat.Diffusers
|
||||
|
||||
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||
assert len(bus.events) >= 3
|
||||
event_names = {x.event_name for x in bus.events}
|
||||
assert event_names == {
|
||||
"model_install_downloading",
|
||||
"model_install_downloads_done",
|
||||
"model_install_running",
|
||||
"model_install_completed",
|
||||
}
|
||||
|
||||
|
||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||
@ -327,7 +349,6 @@ def test_other_error_during_install(
|
||||
assert job.error == "Test error"
|
||||
|
||||
|
||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
||||
@pytest.mark.parametrize(
|
||||
"model_params",
|
||||
[
|
||||
|
@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
for i in ["12345", "9999", "54321"]:
|
||||
content = (
|
||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||
) # for pause tests, must make content large
|
||||
sess.mount(
|
||||
f"http://www.civitai.com/models/{i}",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
sess.mount(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": 'filename="foo.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
"http://www.civitai.com/models/missing",
|
||||
TestAdapter(
|
||||
b"Missing content length",
|
||||
headers={
|
||||
"Content-Disposition": 'filename="missing.txt"',
|
||||
},
|
||||
),
|
||||
)
|
||||
# not found test
|
||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||
|
||||
return sess
|
||||
|
Loading…
Reference in New Issue
Block a user