mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor model_install to work with refactored download queue
This commit is contained in:
@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E
|
||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||
|
||||
|
||||
class DownloadJobBase(BaseModel):
|
||||
"""Base of classes to monitor and control downloads."""
|
||||
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
|
||||
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||
@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase):
|
||||
# required variables to be passed in on creation
|
||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||
# automatically assigned on creation
|
||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||
|
||||
# set internally during download process
|
||||
@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: Set[RemoteModelFile],
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||
pass
|
||||
|
||||
@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
|
@ -113,18 +113,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
raise ServiceInactiveException(
|
||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||
)
|
||||
with self._lock:
|
||||
job.id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
job.id = self._next_id()
|
||||
job.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
on_complete=on_complete,
|
||||
on_cancelled=on_cancelled,
|
||||
on_error=on_error,
|
||||
)
|
||||
self._jobs[job.id] = job
|
||||
self._queue.put(job)
|
||||
|
||||
def download(
|
||||
self,
|
||||
@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
def multifile_download(
|
||||
self,
|
||||
parts: Set[RemoteModelFile],
|
||||
parts: List[RemoteModelFile],
|
||||
dest: Path,
|
||||
access_token: Optional[str] = None,
|
||||
submit_job: bool = True,
|
||||
on_start: Optional[DownloadEventHandler] = None,
|
||||
on_progress: Optional[DownloadEventHandler] = None,
|
||||
on_complete: Optional[DownloadEventHandler] = None,
|
||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||
on_error: Optional[DownloadExceptionHandler] = None,
|
||||
) -> MultiFileDownloadJob:
|
||||
mfdj = MultiFileDownloadJob(dest=dest)
|
||||
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
|
||||
mfdj.set_callbacks(
|
||||
on_start=on_start,
|
||||
on_progress=on_progress,
|
||||
@ -190,7 +189,8 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
)
|
||||
mfdj.download_parts.add(job)
|
||||
self._download_part2parent[job.source] = mfdj
|
||||
self.submit_multifile_download(mfdj)
|
||||
if submit_job:
|
||||
self.submit_multifile_download(mfdj)
|
||||
return mfdj
|
||||
|
||||
def submit_multifile_download(self, job: MultiFileDownloadJob) -> None:
|
||||
@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
"""Wait for all jobs to complete."""
|
||||
self._queue.join()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
with self._lock:
|
||||
id = self._next_job_id
|
||||
self._next_job_id += 1
|
||||
return id
|
||||
|
||||
def list_jobs(self) -> List[DownloadJob]:
|
||||
"""List all the jobs."""
|
||||
return list(self._jobs.values())
|
||||
@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
except KeyError as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
|
||||
def cancel_job(self, job: DownloadJob) -> None:
|
||||
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||
"""
|
||||
Cancel the indicated job.
|
||||
|
||||
@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if mf_job.waiting:
|
||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||
mf_job.status = DownloadJobStatus.RUNNING
|
||||
assert download_job.download_path is not None
|
||||
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
|
||||
mf_job.download_path = (
|
||||
mf_job.dest / path_relative_to_destdir.parts[0]
|
||||
) # keep just the first component of the path
|
||||
self._execute_cb(mf_job, "on_start")
|
||||
|
||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||
|
@ -6,14 +6,14 @@ import traceback
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||
@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel):
|
||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||
default=None, description="Metadata provided by the model source"
|
||||
)
|
||||
download_parts: Set[DownloadJob] = Field(
|
||||
default_factory=set, description="Download jobs contributing to this install"
|
||||
)
|
||||
error: Optional[str] = Field(
|
||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||
)
|
||||
@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel):
|
||||
)
|
||||
# internal flags and transitory settings
|
||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||
_do_install: Optional[bool] = PrivateAttr(default=True)
|
||||
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||
|
||||
def set_error(self, e: Exception) -> None:
|
||||
@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC):
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def download_diffusers_model(
|
||||
self,
|
||||
source: HFModelSource,
|
||||
download_to: Path,
|
||||
) -> ModelInstallJob:
|
||||
"""
|
||||
Download, but do not install, a diffusers model.
|
||||
|
||||
:param source: An HFModelSource object containing a repo_id
|
||||
:param download_to: Path to directory that will contain the downloaded model.
|
||||
|
||||
Returns: a ModelInstallJob
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
||||
from queue import Empty, Queue
|
||||
from shutil import copyfile, copytree, move, rmtree
|
||||
from tempfile import mkdtemp
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||
@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
self._session = session
|
||||
self._install_thread: Optional[threading.Thread] = None
|
||||
@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._install_jobs.append(install_job)
|
||||
return install_job
|
||||
|
||||
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
|
||||
return self._import_from_hf(source, download_path=download_to)
|
||||
|
||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||
return self._install_jobs
|
||||
|
||||
@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
job.cancel()
|
||||
with self._lock:
|
||||
self._cancel_download_parts(job)
|
||||
self._logger.warning(f"Cancelling {job.source}")
|
||||
if dj := job._download_job:
|
||||
self._download_queue.cancel_job(dj)
|
||||
|
||||
def prune_jobs(self) -> None:
|
||||
"""Prune all completed and errored jobs."""
|
||||
@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
legacy_config_path = stanza.get("config")
|
||||
if legacy_config_path:
|
||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
||||
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||
config["config_path"] = str(legacy_config_path)
|
||||
@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
job.config_out = self.record_store.get_model(key)
|
||||
self._signal_job_completed(job)
|
||||
|
||||
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
||||
job.set_error(
|
||||
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||
download_job = install_job._download_job
|
||||
if download_job and any(
|
||||
x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts
|
||||
):
|
||||
install_job.set_error(
|
||||
InvalidModelConfigException(
|
||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||
)
|
||||
)
|
||||
else:
|
||||
job.set_error(excp)
|
||||
self._signal_job_errored(job)
|
||||
install_job.set_error(excp)
|
||||
self._signal_job_errored(install_job)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the models directory
|
||||
@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
||||
only situations in which we may have orphaned models in the models directory.
|
||||
"""
|
||||
|
||||
installed_model_paths = {
|
||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||
}
|
||||
@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self,
|
||||
source: HFModelSource,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
download_path: Optional[Path] = None,
|
||||
) -> ModelInstallJob:
|
||||
# Add user's cached access token to HuggingFace requests
|
||||
source.access_token = source.access_token or HfFolder.get_token()
|
||||
@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
config=config,
|
||||
remote_files=remote_files,
|
||||
metadata=metadata,
|
||||
download_path=download_path,
|
||||
)
|
||||
|
||||
def _import_from_url(
|
||||
@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
remote_files: List[RemoteModelFile],
|
||||
metadata: Optional[AnyModelRepoMetadata],
|
||||
config: Optional[Dict[str, Any]],
|
||||
download_path: Optional[Path] = None, # if defined, download only - don't install!
|
||||
) -> ModelInstallJob:
|
||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
||||
# being held in a daemon thread.
|
||||
if len(remote_files) == 0:
|
||||
raise ValueError(f"{source}: No downloadable files found")
|
||||
destdir = download_path or Path(
|
||||
destdir = Path(
|
||||
mkdtemp(
|
||||
dir=self._app_config.models_path,
|
||||
prefix=TMPDIR_PREFIX,
|
||||
@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
bytes=0,
|
||||
total_bytes=0,
|
||||
)
|
||||
# remember the temporary directory for later removal
|
||||
install_job._install_tmpdir = destdir
|
||||
|
||||
# In the event that there is a subfolder specified in the source,
|
||||
# we need to remove it from the destination path in order to avoid
|
||||
# creating unwanted subfolders
|
||||
@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# we remember the path up to the top of the destdir so that it may be
|
||||
# removed safely at the end of the install process.
|
||||
install_job._install_tmpdir = destdir
|
||||
install_job._do_install = download_path is None
|
||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
||||
|
||||
parts: List[RemoteModelFile] = []
|
||||
for model_file in remote_files:
|
||||
assert install_job.total_bytes is not None
|
||||
assert model_file.size is not None
|
||||
install_job.total_bytes += model_file.size
|
||||
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
|
||||
multifile_job = self._download_queue.multifile_download(
|
||||
parts=parts,
|
||||
dest=destdir,
|
||||
access_token=source.access_token,
|
||||
submit_job=False,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
self._download_cache[multifile_job.id] = install_job
|
||||
install_job._download_job = multifile_job
|
||||
|
||||
files_string = "file" if len(remote_files) == 1 else "file"
|
||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||
self._logger.debug(f"remote_files={remote_files}")
|
||||
for model_file in remote_files:
|
||||
url = model_file.url
|
||||
path = root / model_file.path.relative_to(subfolder)
|
||||
self._logger.debug(f"Downloading {url} => {path}")
|
||||
install_job.total_bytes += model_file.size
|
||||
assert hasattr(source, "access_token")
|
||||
dest = destdir / path.parent
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
download_job = DownloadJob(
|
||||
source=url,
|
||||
dest=dest,
|
||||
access_token=source.access_token,
|
||||
)
|
||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
||||
install_job.download_parts.add(download_job)
|
||||
|
||||
# only start the jobs once install_job.download_parts is fully populated
|
||||
for download_job in install_job.download_parts:
|
||||
self._download_queue.submit_download_job(
|
||||
download_job,
|
||||
on_start=self._download_started_callback,
|
||||
on_progress=self._download_progress_callback,
|
||||
on_complete=self._download_complete_callback,
|
||||
on_error=self._download_error_callback,
|
||||
on_cancelled=self._download_cancelled_callback,
|
||||
)
|
||||
|
||||
self._download_queue.submit_multifile_download(multifile_job)
|
||||
return install_job
|
||||
|
||||
def _stat_size(self, path: Path) -> int:
|
||||
@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# ------------------------------------------------------------------
|
||||
# Callbacks are executed by the download queue in a separate thread
|
||||
# ------------------------------------------------------------------
|
||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download started: {download_job.source}")
|
||||
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job = self._download_cache[download_job.id]
|
||||
install_job.status = InstallStatus.DOWNLOADING
|
||||
|
||||
assert download_job.download_path
|
||||
if install_job.local_path == install_job._install_tmpdir:
|
||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
||||
dest_name = partial_path.parts[0]
|
||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
||||
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||
install_job.local_path = download_job.download_path
|
||||
install_job.total_bytes = download_job.total_bytes
|
||||
|
||||
# Update the total bytes count for remote sources.
|
||||
if not install_job.total_bytes:
|
||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
||||
|
||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
install_job = self._download_cache[download_job.id]
|
||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
else:
|
||||
# update sizes
|
||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
||||
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||
self._signal_job_downloading(install_job)
|
||||
|
||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
||||
self._logger.info(f"Model download complete: {download_job.source}")
|
||||
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache[download_job.source]
|
||||
|
||||
# are there any more active jobs left in this task?
|
||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
||||
self._signal_job_downloads_done(install_job)
|
||||
if install_job._do_install:
|
||||
self._put_in_queue(install_job)
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
self._signal_job_downloads_done(install_job)
|
||||
self._put_in_queue(install_job) # this starts the installation and registration
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._download_cache.pop(download_job.source, None)
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
install_job = self._download_cache.pop(download_job.id)
|
||||
assert install_job is not None
|
||||
assert excp is not None
|
||||
install_job.set_error(excp)
|
||||
self._logger.error(
|
||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
||||
)
|
||||
self._cancel_download_parts(install_job)
|
||||
self._download_queue.cancel_job(download_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
||||
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||
with self._lock:
|
||||
install_job = self._download_cache.pop(download_job.source, None)
|
||||
install_job = self._download_cache.pop(download_job.id, None)
|
||||
if not install_job:
|
||||
return
|
||||
self._downloads_changed_event.set()
|
||||
self._logger.warning(f"Model download canceled: {download_job.source}")
|
||||
# if install job has already registered an error, then do not replace its status with cancelled
|
||||
if not install_job.errored:
|
||||
install_job.cancel()
|
||||
self._cancel_download_parts(install_job)
|
||||
|
||||
# Let other threads know that the number of downloads has changed
|
||||
self._downloads_changed_event.set()
|
||||
|
||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
||||
# do not lock here because it gets called within a locked context
|
||||
for s in install_job.download_parts:
|
||||
self._download_queue.cancel_job(s)
|
||||
|
||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
||||
self._put_in_queue(install_job)
|
||||
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
# Internal methods that put events on the event bus
|
||||
# ------------------------------------------------------------------------------------------------
|
||||
@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
|
||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||
if self._event_bus:
|
||||
assert job._download_job is not None
|
||||
parts: List[Dict[str, str | int]] = [
|
||||
{
|
||||
"url": str(x.source),
|
||||
@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
"bytes": x.bytes,
|
||||
"total_bytes": x.total_bytes,
|
||||
}
|
||||
for x in job.download_parts
|
||||
for x in job._download_job.download_parts
|
||||
]
|
||||
assert job.bytes is not None
|
||||
assert job.total_bytes is not None
|
||||
@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||
|
||||
@staticmethod
|
||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
||||
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||
"""
|
||||
Return a metadata fetcher appropriate for provided url.
|
||||
|
||||
This used to be more useful, but the number of supported model
|
||||
sources has been reduced to HuggingFace alone.
|
||||
"""
|
||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||
return HuggingFaceMetadataFetch
|
||||
raise ValueError(f"Unsupported model source: '{url}'")
|
||||
|
@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel):
|
||||
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(str(self))
|
||||
|
||||
|
||||
class ModelMetadataBase(BaseModel):
|
||||
"""Base class for model metadata information."""
|
||||
|
Reference in New Issue
Block a user