mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor model_install to work with refactored download queue
This commit is contained in:
parent
287c679f7b
commit
f29c406fed
@ -42,9 +42,13 @@ MultiFileDownloadExceptionHandler = Callable[["MultiFileDownloadJob", Optional[E
|
|||||||
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
DownloadEventHandler = Union[SingleFileDownloadEventHandler, MultiFileDownloadEventHandler]
|
||||||
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
DownloadExceptionHandler = Union[SingleFileDownloadExceptionHandler, MultiFileDownloadExceptionHandler]
|
||||||
|
|
||||||
|
|
||||||
class DownloadJobBase(BaseModel):
|
class DownloadJobBase(BaseModel):
|
||||||
"""Base of classes to monitor and control downloads."""
|
"""Base of classes to monitor and control downloads."""
|
||||||
|
|
||||||
|
# automatically assigned on creation
|
||||||
|
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
||||||
|
|
||||||
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
dest: Path = Field(description="Initial destination of downloaded model on local disk; a directory or file path")
|
||||||
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
download_path: Optional[Path] = Field(default=None, description="Final location of downloaded file or directory")
|
||||||
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
status: DownloadJobStatus = Field(default=DownloadJobStatus.WAITING, description="Status of the download")
|
||||||
@ -149,8 +153,6 @@ class DownloadJob(DownloadJobBase):
|
|||||||
# required variables to be passed in on creation
|
# required variables to be passed in on creation
|
||||||
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
source: AnyHttpUrl = Field(description="Where to download from. Specific types specified in child classes.")
|
||||||
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
access_token: Optional[str] = Field(default=None, description="authorization token for protected resources")
|
||||||
# automatically assigned on creation
|
|
||||||
id: int = Field(description="Numeric ID of this job", default=-1) # default id is a sentinel
|
|
||||||
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
|
||||||
|
|
||||||
# set internally during download process
|
# set internally during download process
|
||||||
@ -225,7 +227,7 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def multifile_download(
|
def multifile_download(
|
||||||
self,
|
self,
|
||||||
parts: Set[RemoteModelFile],
|
parts: List[RemoteModelFile],
|
||||||
dest: Path,
|
dest: Path,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
submit_job: bool = True,
|
submit_job: bool = True,
|
||||||
@ -315,7 +317,7 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def cancel_job(self, job: DownloadJob) -> None:
|
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||||
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
"""Cancel the job, clearing partial downloads and putting it into ERROR state."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -325,7 +327,7 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||||
"""Wait until the indicated download job has reached a terminal state.
|
"""Wait until the indicated download job has reached a terminal state.
|
||||||
|
|
||||||
This will block until the indicated install job has completed,
|
This will block until the indicated install job has completed,
|
||||||
|
@ -113,9 +113,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
raise ServiceInactiveException(
|
raise ServiceInactiveException(
|
||||||
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
"The download service is not currently accepting requests. Please call start() to initialize the service."
|
||||||
)
|
)
|
||||||
with self._lock:
|
job.id = self._next_id()
|
||||||
job.id = self._next_job_id
|
|
||||||
self._next_job_id += 1
|
|
||||||
job.set_callbacks(
|
job.set_callbacks(
|
||||||
on_start=on_start,
|
on_start=on_start,
|
||||||
on_progress=on_progress,
|
on_progress=on_progress,
|
||||||
@ -161,16 +159,17 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
|
|
||||||
def multifile_download(
|
def multifile_download(
|
||||||
self,
|
self,
|
||||||
parts: Set[RemoteModelFile],
|
parts: List[RemoteModelFile],
|
||||||
dest: Path,
|
dest: Path,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
|
submit_job: bool = True,
|
||||||
on_start: Optional[DownloadEventHandler] = None,
|
on_start: Optional[DownloadEventHandler] = None,
|
||||||
on_progress: Optional[DownloadEventHandler] = None,
|
on_progress: Optional[DownloadEventHandler] = None,
|
||||||
on_complete: Optional[DownloadEventHandler] = None,
|
on_complete: Optional[DownloadEventHandler] = None,
|
||||||
on_cancelled: Optional[DownloadEventHandler] = None,
|
on_cancelled: Optional[DownloadEventHandler] = None,
|
||||||
on_error: Optional[DownloadExceptionHandler] = None,
|
on_error: Optional[DownloadExceptionHandler] = None,
|
||||||
) -> MultiFileDownloadJob:
|
) -> MultiFileDownloadJob:
|
||||||
mfdj = MultiFileDownloadJob(dest=dest)
|
mfdj = MultiFileDownloadJob(dest=dest, id=self._next_id())
|
||||||
mfdj.set_callbacks(
|
mfdj.set_callbacks(
|
||||||
on_start=on_start,
|
on_start=on_start,
|
||||||
on_progress=on_progress,
|
on_progress=on_progress,
|
||||||
@ -190,6 +189,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
)
|
)
|
||||||
mfdj.download_parts.add(job)
|
mfdj.download_parts.add(job)
|
||||||
self._download_part2parent[job.source] = mfdj
|
self._download_part2parent[job.source] = mfdj
|
||||||
|
if submit_job:
|
||||||
self.submit_multifile_download(mfdj)
|
self.submit_multifile_download(mfdj)
|
||||||
return mfdj
|
return mfdj
|
||||||
|
|
||||||
@ -208,6 +208,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
"""Wait for all jobs to complete."""
|
"""Wait for all jobs to complete."""
|
||||||
self._queue.join()
|
self._queue.join()
|
||||||
|
|
||||||
|
def _next_id(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
id = self._next_job_id
|
||||||
|
self._next_job_id += 1
|
||||||
|
return id
|
||||||
|
|
||||||
def list_jobs(self) -> List[DownloadJob]:
|
def list_jobs(self) -> List[DownloadJob]:
|
||||||
"""List all the jobs."""
|
"""List all the jobs."""
|
||||||
return list(self._jobs.values())
|
return list(self._jobs.values())
|
||||||
@ -229,7 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
except KeyError as excp:
|
except KeyError as excp:
|
||||||
raise UnknownJobIDException("Unrecognized job") from excp
|
raise UnknownJobIDException("Unrecognized job") from excp
|
||||||
|
|
||||||
def cancel_job(self, job: DownloadJob) -> None:
|
def cancel_job(self, job: DownloadJobBase) -> None:
|
||||||
"""
|
"""
|
||||||
Cancel the indicated job.
|
Cancel the indicated job.
|
||||||
|
|
||||||
@ -245,7 +251,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
if not job.in_terminal_state:
|
if not job.in_terminal_state:
|
||||||
self.cancel_job(job)
|
self.cancel_job(job)
|
||||||
|
|
||||||
def wait_for_job(self, job: DownloadJob | MultiFileDownloadJob, timeout: int = 0) -> DownloadJob:
|
def wait_for_job(self, job: DownloadJobBase, timeout: int = 0) -> DownloadJobBase:
|
||||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
while not job.in_terminal_state:
|
while not job.in_terminal_state:
|
||||||
@ -468,6 +474,11 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
if mf_job.waiting:
|
if mf_job.waiting:
|
||||||
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
mf_job.total_bytes = sum(x.total_bytes for x in mf_job.download_parts)
|
||||||
mf_job.status = DownloadJobStatus.RUNNING
|
mf_job.status = DownloadJobStatus.RUNNING
|
||||||
|
assert download_job.download_path is not None
|
||||||
|
path_relative_to_destdir = download_job.download_path.relative_to(mf_job.dest)
|
||||||
|
mf_job.download_path = (
|
||||||
|
mf_job.dest / path_relative_to_destdir.parts[0]
|
||||||
|
) # keep just the first component of the path
|
||||||
self._execute_cb(mf_job, "on_start")
|
self._execute_cb(mf_job, "on_start")
|
||||||
|
|
||||||
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
def _mfd_progress(self, download_job: DownloadJob) -> None:
|
||||||
|
@ -6,14 +6,14 @@ import traceback
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Set, Union
|
from typing import Any, Dict, List, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import ModelRecordServiceBase
|
from invokeai.app.services.model_records import ModelRecordServiceBase
|
||||||
@ -166,9 +166,6 @@ class ModelInstallJob(BaseModel):
|
|||||||
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
source_metadata: Optional[AnyModelRepoMetadata] = Field(
|
||||||
default=None, description="Metadata provided by the model source"
|
default=None, description="Metadata provided by the model source"
|
||||||
)
|
)
|
||||||
download_parts: Set[DownloadJob] = Field(
|
|
||||||
default_factory=set, description="Download jobs contributing to this install"
|
|
||||||
)
|
|
||||||
error: Optional[str] = Field(
|
error: Optional[str] = Field(
|
||||||
default=None, description="On an error condition, this field will contain the text of the exception"
|
default=None, description="On an error condition, this field will contain the text of the exception"
|
||||||
)
|
)
|
||||||
@ -177,7 +174,7 @@ class ModelInstallJob(BaseModel):
|
|||||||
)
|
)
|
||||||
# internal flags and transitory settings
|
# internal flags and transitory settings
|
||||||
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
_install_tmpdir: Optional[Path] = PrivateAttr(default=None)
|
||||||
_do_install: Optional[bool] = PrivateAttr(default=True)
|
_download_job: Optional[MultiFileDownloadJob] = PrivateAttr(default=None)
|
||||||
_exception: Optional[Exception] = PrivateAttr(default=None)
|
_exception: Optional[Exception] = PrivateAttr(default=None)
|
||||||
|
|
||||||
def set_error(self, e: Exception) -> None:
|
def set_error(self, e: Exception) -> None:
|
||||||
@ -408,21 +405,6 @@ class ModelInstallServiceBase(ABC):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def download_diffusers_model(
|
|
||||||
self,
|
|
||||||
source: HFModelSource,
|
|
||||||
download_to: Path,
|
|
||||||
) -> ModelInstallJob:
|
|
||||||
"""
|
|
||||||
Download, but do not install, a diffusers model.
|
|
||||||
|
|
||||||
:param source: An HFModelSource object containing a repo_id
|
|
||||||
:param download_to: Path to directory that will contain the downloaded model.
|
|
||||||
|
|
||||||
Returns: a ModelInstallJob
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
def get_job_by_source(self, source: ModelSource) -> List[ModelInstallJob]:
|
||||||
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
"""Return the ModelInstallJob(s) corresponding to the provided source."""
|
||||||
|
@ -9,7 +9,7 @@ from pathlib import Path
|
|||||||
from queue import Empty, Queue
|
from queue import Empty, Queue
|
||||||
from shutil import copyfile, copytree, move, rmtree
|
from shutil import copyfile, copytree, move, rmtree
|
||||||
from tempfile import mkdtemp
|
from tempfile import mkdtemp
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -18,7 +18,7 @@ from pydantic.networks import AnyHttpUrl
|
|||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
from invokeai.app.services.download import DownloadQueueServiceBase, MultiFileDownloadJob, TqdmProgress
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase
|
||||||
@ -89,7 +89,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._downloads_changed_event = threading.Event()
|
self._downloads_changed_event = threading.Event()
|
||||||
self._install_completed_event = threading.Event()
|
self._install_completed_event = threading.Event()
|
||||||
self._download_queue = download_queue
|
self._download_queue = download_queue
|
||||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
self._download_cache: Dict[int, ModelInstallJob] = {}
|
||||||
self._running = False
|
self._running = False
|
||||||
self._session = session
|
self._session = session
|
||||||
self._install_thread: Optional[threading.Thread] = None
|
self._install_thread: Optional[threading.Thread] = None
|
||||||
@ -249,9 +249,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._install_jobs.append(install_job)
|
self._install_jobs.append(install_job)
|
||||||
return install_job
|
return install_job
|
||||||
|
|
||||||
def download_diffusers_model(self, source: HFModelSource, download_to: Path) -> ModelInstallJob:
|
|
||||||
return self._import_from_hf(source, download_path=download_to)
|
|
||||||
|
|
||||||
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
def list_jobs(self) -> List[ModelInstallJob]: # noqa D102
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
@ -291,8 +288,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||||
"""Cancel the indicated job."""
|
"""Cancel the indicated job."""
|
||||||
job.cancel()
|
job.cancel()
|
||||||
with self._lock:
|
self._logger.warning(f"Cancelling {job.source}")
|
||||||
self._cancel_download_parts(job)
|
if dj := job._download_job:
|
||||||
|
self._download_queue.cancel_job(dj)
|
||||||
|
|
||||||
def prune_jobs(self) -> None:
|
def prune_jobs(self) -> None:
|
||||||
"""Prune all completed and errored jobs."""
|
"""Prune all completed and errored jobs."""
|
||||||
@ -340,7 +338,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
legacy_config_path = stanza.get("config")
|
legacy_config_path = stanza.get("config")
|
||||||
if legacy_config_path:
|
if legacy_config_path:
|
||||||
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
# In v3, these paths were relative to the root. Migrate them to be relative to the legacy_conf_dir.
|
||||||
legacy_config_path: Path = self._app_config.root_path / legacy_config_path
|
legacy_config_path = self._app_config.root_path / legacy_config_path
|
||||||
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
if legacy_config_path.is_relative_to(self._app_config.legacy_conf_path):
|
||||||
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
legacy_config_path = legacy_config_path.relative_to(self._app_config.legacy_conf_path)
|
||||||
config["config_path"] = str(legacy_config_path)
|
config["config_path"] = str(legacy_config_path)
|
||||||
@ -476,16 +474,19 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
job.config_out = self.record_store.get_model(key)
|
job.config_out = self.record_store.get_model(key)
|
||||||
self._signal_job_completed(job)
|
self._signal_job_completed(job)
|
||||||
|
|
||||||
def _set_error(self, job: ModelInstallJob, excp: Exception) -> None:
|
def _set_error(self, install_job: ModelInstallJob, excp: Exception) -> None:
|
||||||
if any(x.content_type is not None and "text/html" in x.content_type for x in job.download_parts):
|
download_job = install_job._download_job
|
||||||
job.set_error(
|
if download_job and any(
|
||||||
|
x.content_type is not None and "text/html" in x.content_type for x in download_job.download_parts
|
||||||
|
):
|
||||||
|
install_job.set_error(
|
||||||
InvalidModelConfigException(
|
InvalidModelConfigException(
|
||||||
f"At least one file in {job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
f"At least one file in {install_job.local_path} is an HTML page, not a model. This can happen when an access token is required to download."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
job.set_error(excp)
|
install_job.set_error(excp)
|
||||||
self._signal_job_errored(job)
|
self._signal_job_errored(install_job)
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
# Internal functions that manage the models directory
|
# Internal functions that manage the models directory
|
||||||
@ -511,7 +512,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
This is typically only used during testing with a new DB or when using the memory DB, because those are the
|
||||||
only situations in which we may have orphaned models in the models directory.
|
only situations in which we may have orphaned models in the models directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
installed_model_paths = {
|
installed_model_paths = {
|
||||||
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
(self._app_config.models_path / x.path).resolve() for x in self.record_store.all_models()
|
||||||
}
|
}
|
||||||
@ -648,7 +648,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self,
|
self,
|
||||||
source: HFModelSource,
|
source: HFModelSource,
|
||||||
config: Optional[Dict[str, Any]] = None,
|
config: Optional[Dict[str, Any]] = None,
|
||||||
download_path: Optional[Path] = None,
|
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
# Add user's cached access token to HuggingFace requests
|
# Add user's cached access token to HuggingFace requests
|
||||||
source.access_token = source.access_token or HfFolder.get_token()
|
source.access_token = source.access_token or HfFolder.get_token()
|
||||||
@ -668,7 +667,6 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
config=config,
|
config=config,
|
||||||
remote_files=remote_files,
|
remote_files=remote_files,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
download_path=download_path,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _import_from_url(
|
def _import_from_url(
|
||||||
@ -704,14 +702,10 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
remote_files: List[RemoteModelFile],
|
remote_files: List[RemoteModelFile],
|
||||||
metadata: Optional[AnyModelRepoMetadata],
|
metadata: Optional[AnyModelRepoMetadata],
|
||||||
config: Optional[Dict[str, Any]],
|
config: Optional[Dict[str, Any]],
|
||||||
download_path: Optional[Path] = None, # if defined, download only - don't install!
|
|
||||||
) -> ModelInstallJob:
|
) -> ModelInstallJob:
|
||||||
# TODO: Replace with tempfile.tmpdir() when multithreading is cleaned up.
|
|
||||||
# Currently the tmpdir isn't automatically removed at exit because it is
|
|
||||||
# being held in a daemon thread.
|
|
||||||
if len(remote_files) == 0:
|
if len(remote_files) == 0:
|
||||||
raise ValueError(f"{source}: No downloadable files found")
|
raise ValueError(f"{source}: No downloadable files found")
|
||||||
destdir = download_path or Path(
|
destdir = Path(
|
||||||
mkdtemp(
|
mkdtemp(
|
||||||
dir=self._app_config.models_path,
|
dir=self._app_config.models_path,
|
||||||
prefix=TMPDIR_PREFIX,
|
prefix=TMPDIR_PREFIX,
|
||||||
@ -726,6 +720,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
bytes=0,
|
bytes=0,
|
||||||
total_bytes=0,
|
total_bytes=0,
|
||||||
)
|
)
|
||||||
|
# remember the temporary directory for later removal
|
||||||
|
install_job._install_tmpdir = destdir
|
||||||
|
|
||||||
# In the event that there is a subfolder specified in the source,
|
# In the event that there is a subfolder specified in the source,
|
||||||
# we need to remove it from the destination path in order to avoid
|
# we need to remove it from the destination path in order to avoid
|
||||||
# creating unwanted subfolders
|
# creating unwanted subfolders
|
||||||
@ -739,39 +736,31 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# we remember the path up to the top of the destdir so that it may be
|
# we remember the path up to the top of the destdir so that it may be
|
||||||
# removed safely at the end of the install process.
|
# removed safely at the end of the install process.
|
||||||
install_job._install_tmpdir = destdir
|
install_job._install_tmpdir = destdir
|
||||||
install_job._do_install = download_path is None
|
|
||||||
assert install_job.total_bytes is not None # to avoid type checking complaints in the loop below
|
|
||||||
|
|
||||||
files_string = "file" if len(remote_files) == 1 else "file"
|
parts: List[RemoteModelFile] = []
|
||||||
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
|
||||||
self._logger.debug(f"remote_files={remote_files}")
|
|
||||||
for model_file in remote_files:
|
for model_file in remote_files:
|
||||||
url = model_file.url
|
assert install_job.total_bytes is not None
|
||||||
path = root / model_file.path.relative_to(subfolder)
|
assert model_file.size is not None
|
||||||
self._logger.debug(f"Downloading {url} => {path}")
|
|
||||||
install_job.total_bytes += model_file.size
|
install_job.total_bytes += model_file.size
|
||||||
assert hasattr(source, "access_token")
|
parts.append(RemoteModelFile(url=model_file.url, path=model_file.path.relative_to(subfolder)))
|
||||||
dest = destdir / path.parent
|
multifile_job = self._download_queue.multifile_download(
|
||||||
dest.mkdir(parents=True, exist_ok=True)
|
parts=parts,
|
||||||
download_job = DownloadJob(
|
dest=destdir,
|
||||||
source=url,
|
|
||||||
dest=dest,
|
|
||||||
access_token=source.access_token,
|
access_token=source.access_token,
|
||||||
)
|
submit_job=False,
|
||||||
self._download_cache[download_job.source] = install_job # matches a download job to an install job
|
|
||||||
install_job.download_parts.add(download_job)
|
|
||||||
|
|
||||||
# only start the jobs once install_job.download_parts is fully populated
|
|
||||||
for download_job in install_job.download_parts:
|
|
||||||
self._download_queue.submit_download_job(
|
|
||||||
download_job,
|
|
||||||
on_start=self._download_started_callback,
|
on_start=self._download_started_callback,
|
||||||
on_progress=self._download_progress_callback,
|
on_progress=self._download_progress_callback,
|
||||||
on_complete=self._download_complete_callback,
|
on_complete=self._download_complete_callback,
|
||||||
on_error=self._download_error_callback,
|
on_error=self._download_error_callback,
|
||||||
on_cancelled=self._download_cancelled_callback,
|
on_cancelled=self._download_cancelled_callback,
|
||||||
)
|
)
|
||||||
|
self._download_cache[multifile_job.id] = install_job
|
||||||
|
install_job._download_job = multifile_job
|
||||||
|
|
||||||
|
files_string = "file" if len(remote_files) == 1 else "file"
|
||||||
|
self._logger.info(f"Queuing model install: {source} ({len(remote_files)} {files_string})")
|
||||||
|
self._logger.debug(f"remote_files={remote_files}")
|
||||||
|
self._download_queue.submit_multifile_download(multifile_job)
|
||||||
return install_job
|
return install_job
|
||||||
|
|
||||||
def _stat_size(self, path: Path) -> int:
|
def _stat_size(self, path: Path) -> int:
|
||||||
@ -786,86 +775,59 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Callbacks are executed by the download queue in a separate thread
|
# Callbacks are executed by the download queue in a separate thread
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
def _download_started_callback(self, download_job: DownloadJob) -> None:
|
def _download_started_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||||
self._logger.info(f"Model download started: {download_job.source}")
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache[download_job.id]
|
||||||
install_job.status = InstallStatus.DOWNLOADING
|
install_job.status = InstallStatus.DOWNLOADING
|
||||||
|
|
||||||
assert download_job.download_path
|
assert download_job.download_path
|
||||||
if install_job.local_path == install_job._install_tmpdir:
|
if install_job.local_path == install_job._install_tmpdir: # first time
|
||||||
partial_path = download_job.download_path.relative_to(install_job._install_tmpdir)
|
install_job.local_path = download_job.download_path
|
||||||
dest_name = partial_path.parts[0]
|
install_job.total_bytes = download_job.total_bytes
|
||||||
install_job.local_path = install_job._install_tmpdir / dest_name
|
|
||||||
|
|
||||||
# Update the total bytes count for remote sources.
|
def _download_progress_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||||
if not install_job.total_bytes:
|
|
||||||
install_job.total_bytes = sum(x.total_bytes for x in install_job.download_parts)
|
|
||||||
|
|
||||||
def _download_progress_callback(self, download_job: DownloadJob) -> None:
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache[download_job.id]
|
||||||
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
if install_job.cancelled: # This catches the case in which the caller directly calls job.cancel()
|
||||||
self._cancel_download_parts(install_job)
|
self._download_queue.cancel_job(download_job)
|
||||||
else:
|
else:
|
||||||
# update sizes
|
# update sizes
|
||||||
install_job.bytes = sum(x.bytes for x in install_job.download_parts)
|
install_job.bytes = sum(x.bytes for x in download_job.download_parts)
|
||||||
self._signal_job_downloading(install_job)
|
self._signal_job_downloading(install_job)
|
||||||
|
|
||||||
def _download_complete_callback(self, download_job: DownloadJob) -> None:
|
def _download_complete_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||||
self._logger.info(f"Model download complete: {download_job.source}")
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache[download_job.source]
|
install_job = self._download_cache.pop(download_job.id)
|
||||||
|
|
||||||
# are there any more active jobs left in this task?
|
|
||||||
if install_job.downloading and all(x.complete for x in install_job.download_parts):
|
|
||||||
self._signal_job_downloads_done(install_job)
|
self._signal_job_downloads_done(install_job)
|
||||||
if install_job._do_install:
|
self._put_in_queue(install_job) # this starts the installation and registration
|
||||||
self._put_in_queue(install_job)
|
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._download_cache.pop(download_job.source, None)
|
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
|
|
||||||
def _download_error_callback(self, download_job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
def _download_error_callback(self, download_job: MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache.pop(download_job.source, None)
|
install_job = self._download_cache.pop(download_job.id)
|
||||||
assert install_job is not None
|
assert install_job is not None
|
||||||
assert excp is not None
|
assert excp is not None
|
||||||
install_job.set_error(excp)
|
install_job.set_error(excp)
|
||||||
self._logger.error(
|
self._download_queue.cancel_job(download_job)
|
||||||
f"Cancelling {install_job.source} due to an error while downloading {download_job.source}: {str(excp)}"
|
|
||||||
)
|
|
||||||
self._cancel_download_parts(install_job)
|
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
|
|
||||||
def _download_cancelled_callback(self, download_job: DownloadJob) -> None:
|
def _download_cancelled_callback(self, download_job: MultiFileDownloadJob) -> None:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
install_job = self._download_cache.pop(download_job.source, None)
|
install_job = self._download_cache.pop(download_job.id, None)
|
||||||
if not install_job:
|
if not install_job:
|
||||||
return
|
return
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
self._logger.warning(f"Model download canceled: {download_job.source}")
|
|
||||||
# if install job has already registered an error, then do not replace its status with cancelled
|
# if install job has already registered an error, then do not replace its status with cancelled
|
||||||
if not install_job.errored:
|
if not install_job.errored:
|
||||||
install_job.cancel()
|
install_job.cancel()
|
||||||
self._cancel_download_parts(install_job)
|
|
||||||
|
|
||||||
# Let other threads know that the number of downloads has changed
|
# Let other threads know that the number of downloads has changed
|
||||||
self._downloads_changed_event.set()
|
self._downloads_changed_event.set()
|
||||||
|
|
||||||
def _cancel_download_parts(self, install_job: ModelInstallJob) -> None:
|
|
||||||
# on multipart downloads, _cancel_components() will get called repeatedly from the download callbacks
|
|
||||||
# do not lock here because it gets called within a locked context
|
|
||||||
for s in install_job.download_parts:
|
|
||||||
self._download_queue.cancel_job(s)
|
|
||||||
|
|
||||||
if all(x.in_terminal_state for x in install_job.download_parts):
|
|
||||||
# When all parts have reached their terminal state, we finalize the job to clean up the temporary directory and other resources
|
|
||||||
self._put_in_queue(install_job)
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
# Internal methods that put events on the event bus
|
# Internal methods that put events on the event bus
|
||||||
# ------------------------------------------------------------------------------------------------
|
# ------------------------------------------------------------------------------------------------
|
||||||
@ -877,6 +839,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
|
|
||||||
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
def _signal_job_downloading(self, job: ModelInstallJob) -> None:
|
||||||
if self._event_bus:
|
if self._event_bus:
|
||||||
|
assert job._download_job is not None
|
||||||
parts: List[Dict[str, str | int]] = [
|
parts: List[Dict[str, str | int]] = [
|
||||||
{
|
{
|
||||||
"url": str(x.source),
|
"url": str(x.source),
|
||||||
@ -884,7 +847,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
"bytes": x.bytes,
|
"bytes": x.bytes,
|
||||||
"total_bytes": x.total_bytes,
|
"total_bytes": x.total_bytes,
|
||||||
}
|
}
|
||||||
for x in job.download_parts
|
for x in job._download_job.download_parts
|
||||||
]
|
]
|
||||||
assert job.bytes is not None
|
assert job.bytes is not None
|
||||||
assert job.total_bytes is not None
|
assert job.total_bytes is not None
|
||||||
@ -929,7 +892,13 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
self._event_bus.emit_model_install_cancelled(str(job.source), id=job.id)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_fetcher_from_url(url: str) -> ModelMetadataFetchBase:
|
def get_fetcher_from_url(url: str) -> Type[ModelMetadataFetchBase]:
|
||||||
|
"""
|
||||||
|
Return a metadata fetcher appropriate for provided url.
|
||||||
|
|
||||||
|
This used to be more useful, but the number of supported model
|
||||||
|
sources has been reduced to HuggingFace alone.
|
||||||
|
"""
|
||||||
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
if re.match(r"^https?://huggingface.co/[^/]+/[^/]+$", url.lower()):
|
||||||
return HuggingFaceMetadataFetch
|
return HuggingFaceMetadataFetch
|
||||||
raise ValueError(f"Unsupported model source: '{url}'")
|
raise ValueError(f"Unsupported model source: '{url}'")
|
||||||
|
@ -40,6 +40,9 @@ class RemoteModelFile(BaseModel):
|
|||||||
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
size: Optional[int] = Field(description="The size of this file, in bytes", default=0)
|
||||||
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
sha256: Optional[str] = Field(description="SHA256 hash of this model (not always available)", default=None)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(str(self))
|
||||||
|
|
||||||
|
|
||||||
class ModelMetadataBase(BaseModel):
|
class ModelMetadataBase(BaseModel):
|
||||||
"""Base class for model metadata information."""
|
"""Base class for model metadata information."""
|
||||||
|
@ -4,79 +4,33 @@ import re
|
|||||||
import time
|
import time
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Generator, Optional
|
from typing import Any, Generator, Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from requests_testadapter import TestAdapter, TestSession
|
from requests_testadapter import TestAdapter
|
||||||
|
|
||||||
from invokeai.app.services.config import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService, MultiFileDownloadJob
|
||||||
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, RemoteModelFile
|
from invokeai.backend.model_manager.metadata import HuggingFaceMetadataFetch, ModelMetadataWithFiles, RemoteModelFile
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
from tests.test_nodes import TestEventService
|
from tests.test_nodes import TestEventService
|
||||||
|
|
||||||
# Prevent pytest deprecation warnings
|
# Prevent pytest deprecation warnings
|
||||||
TestAdapter.__test__ = False # type: ignore
|
TestAdapter.__test__ = False
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def session() -> Session:
|
|
||||||
sess = TestSession()
|
|
||||||
for i in ["12345", "9999", "54321"]:
|
|
||||||
content = (
|
|
||||||
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
|
||||||
) # for pause tests, must make content large
|
|
||||||
sess.mount(
|
|
||||||
f"http://www.civitai.com/models/{i}",
|
|
||||||
TestAdapter(
|
|
||||||
content,
|
|
||||||
headers={
|
|
||||||
"Content-Length": len(content),
|
|
||||||
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
sess.mount(
|
|
||||||
"http://www.huggingface.co/foo.txt",
|
|
||||||
TestAdapter(
|
|
||||||
content,
|
|
||||||
headers={
|
|
||||||
"Content-Length": len(content),
|
|
||||||
"Content-Disposition": 'filename="foo.safetensors"',
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# here are some malformed URLs to test
|
|
||||||
# missing the content length
|
|
||||||
sess.mount(
|
|
||||||
"http://www.civitai.com/models/missing",
|
|
||||||
TestAdapter(
|
|
||||||
b"Missing content length",
|
|
||||||
headers={
|
|
||||||
"Content-Disposition": 'filename="missing.txt"',
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
# not found test
|
|
||||||
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
|
||||||
|
|
||||||
return sess
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
def test_basic_queue_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
events = set()
|
events = set()
|
||||||
|
|
||||||
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
def event_handler(job: DownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
events.add(job.status)
|
events.add(job.status)
|
||||||
|
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
requests_session=session,
|
requests_session=mm2_session,
|
||||||
)
|
)
|
||||||
queue.start()
|
queue.start()
|
||||||
job = queue.download(
|
job = queue.download(
|
||||||
@ -92,6 +46,7 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
|||||||
queue.join()
|
queue.join()
|
||||||
|
|
||||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||||
|
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||||
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||||
|
|
||||||
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
assert events == {DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED}
|
||||||
@ -99,9 +54,9 @@ def test_basic_queue_download(tmp_path: Path, session: Session) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_errors(tmp_path: Path, session: Session) -> None:
|
def test_errors(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
requests_session=session,
|
requests_session=mm2_session,
|
||||||
)
|
)
|
||||||
queue.start()
|
queue.start()
|
||||||
|
|
||||||
@ -121,10 +76,10 @@ def test_errors(tmp_path: Path, session: Session) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_event_bus(tmp_path: Path, session: Session) -> None:
|
def test_event_bus(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||||
queue.start()
|
queue.start()
|
||||||
queue.download(
|
queue.download(
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
@ -157,9 +112,9 @@ def test_event_bus(tmp_path: Path, session: Session) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
def test_broken_callbacks(tmp_path: Path, mm2_session: Session, capsys) -> None:
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
requests_session=session,
|
requests_session=mm2_session,
|
||||||
)
|
)
|
||||||
queue.start()
|
queue.start()
|
||||||
|
|
||||||
@ -189,10 +144,10 @@ def test_broken_callbacks(tmp_path: Path, session: Session, capsys) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_cancel(tmp_path: Path, session: Session) -> None:
|
def test_cancel(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
queue = DownloadQueueService(requests_session=session, event_bus=event_bus)
|
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||||
queue.start()
|
queue.start()
|
||||||
|
|
||||||
cancelled = False
|
cancelled = False
|
||||||
@ -204,9 +159,6 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
|||||||
nonlocal cancelled
|
nonlocal cancelled
|
||||||
cancelled = True
|
cancelled = True
|
||||||
|
|
||||||
def handler(signum, frame):
|
|
||||||
raise TimeoutError("Join took too long to return")
|
|
||||||
|
|
||||||
job = queue.download(
|
job = queue.download(
|
||||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
dest=tmp_path,
|
dest=tmp_path,
|
||||||
@ -223,14 +175,15 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
|||||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
events = set()
|
events = set()
|
||||||
|
|
||||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
print(f"bytes = {job.bytes}")
|
|
||||||
events.add(job.status)
|
events.add(job.status)
|
||||||
|
|
||||||
queue = DownloadQueueService(
|
queue = DownloadQueueService(
|
||||||
@ -251,6 +204,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
|||||||
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||||
assert job.bytes > 0, "expected download bytes to be positive"
|
assert job.bytes > 0, "expected download bytes to be positive"
|
||||||
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||||
|
assert job.download_path == tmp_path / "sdxl-turbo"
|
||||||
assert Path(
|
assert Path(
|
||||||
tmp_path, "sdxl-turbo/model_index.json"
|
tmp_path, "sdxl-turbo/model_index.json"
|
||||||
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
).exists(), f"expected {tmp_path}/sdxl-turbo/model_inded.json to exist"
|
||||||
@ -266,6 +220,7 @@ def test_multifile_download(tmp_path: Path, mm2_session: Session) -> None:
|
|||||||
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
events = set()
|
events = set()
|
||||||
|
|
||||||
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
def event_handler(job: DownloadJob | MultiFileDownloadJob, excp: Optional[Exception] = None) -> None:
|
||||||
@ -289,13 +244,14 @@ def test_multifile_download_error(tmp_path: Path, mm2_session: Session) -> None:
|
|||||||
queue.join()
|
queue.join()
|
||||||
|
|
||||||
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
assert job.status == DownloadJobStatus("error"), "expected job status to be errored"
|
||||||
|
assert job.error_type is not None
|
||||||
assert "HTTPError(NOT FOUND)" in job.error_type
|
assert "HTTPError(NOT FOUND)" in job.error_type
|
||||||
assert DownloadJobStatus.ERROR in events
|
assert DownloadJobStatus.ERROR in events
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=10, method="thread")
|
@pytest.mark.timeout(timeout=10, method="thread")
|
||||||
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) -> None:
|
def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch: Any) -> None:
|
||||||
event_bus = TestEventService()
|
event_bus = TestEventService()
|
||||||
|
|
||||||
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
queue = DownloadQueueService(requests_session=mm2_session, event_bus=event_bus)
|
||||||
@ -307,11 +263,9 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
|
|||||||
nonlocal cancelled
|
nonlocal cancelled
|
||||||
cancelled = True
|
cancelled = True
|
||||||
|
|
||||||
def handler(signum, frame):
|
|
||||||
raise TimeoutError("Join took too long to return")
|
|
||||||
|
|
||||||
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
fetcher = HuggingFaceMetadataFetch(mm2_session)
|
||||||
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
metadata = fetcher.from_id("stabilityai/sdxl-turbo")
|
||||||
|
assert isinstance(metadata, ModelMetadataWithFiles)
|
||||||
|
|
||||||
job = queue.multifile_download(
|
job = queue.multifile_download(
|
||||||
parts=metadata.download_urls(session=mm2_session),
|
parts=metadata.download_urls(session=mm2_session),
|
||||||
@ -327,6 +281,29 @@ def test_multifile_cancel(tmp_path: Path, mm2_session: Session, monkeypatch) ->
|
|||||||
assert "download_cancelled" in [x.event_name for x in events]
|
assert "download_cancelled" in [x.event_name for x in events]
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
def test_multifile_onefile(tmp_path: Path, mm2_session: Session) -> None:
|
||||||
|
queue = DownloadQueueService(
|
||||||
|
requests_session=mm2_session,
|
||||||
|
)
|
||||||
|
queue.start()
|
||||||
|
job = queue.multifile_download(
|
||||||
|
parts=[
|
||||||
|
RemoteModelFile(url=AnyHttpUrl("http://www.civitai.com/models/12345"), path=Path("mock12345.safetensors"))
|
||||||
|
],
|
||||||
|
dest=tmp_path,
|
||||||
|
)
|
||||||
|
assert isinstance(job, MultiFileDownloadJob), "expected the job to be of type MultiFileDownloadJobBase"
|
||||||
|
queue.join()
|
||||||
|
|
||||||
|
assert job.status == DownloadJobStatus("completed"), "expected job status to be completed"
|
||||||
|
assert job.bytes > 0, "expected download bytes to be positive"
|
||||||
|
assert job.bytes == job.total_bytes, "expected download bytes to equal total bytes"
|
||||||
|
assert job.download_path == tmp_path / "mock12345.safetensors"
|
||||||
|
assert Path(tmp_path, "mock12345.safetensors").exists(), f"expected {tmp_path}/mock12345.safetensors to exist"
|
||||||
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def clear_config() -> Generator[None, None, None]:
|
def clear_config() -> Generator[None, None, None]:
|
||||||
try:
|
try:
|
||||||
@ -335,11 +312,11 @@ def clear_config() -> Generator[None, None, None]:
|
|||||||
get_config.cache_clear()
|
get_config.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
def test_tokens(tmp_path: Path, session: Session):
|
def test_tokens(tmp_path: Path, mm2_session: Session):
|
||||||
with clear_config():
|
with clear_config():
|
||||||
config = get_config()
|
config = get_config()
|
||||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||||
queue = DownloadQueueService(requests_session=session)
|
queue = DownloadQueueService(requests_session=mm2_session)
|
||||||
queue.start()
|
queue.start()
|
||||||
# this one has an access token assigned
|
# this one has an access token assigned
|
||||||
job1 = queue.download(
|
job1 = queue.download(
|
||||||
|
@ -286,14 +286,36 @@ def test_huggingface_install(mm2_installer: ModelInstallServiceBase, mm2_app_con
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.timeout(timeout=20, method="thread")
|
@pytest.mark.timeout(timeout=20, method="thread")
|
||||||
def test_huggingface_download(mm2_installer: ModelInstallServiceBase, tmp_path: Path) -> None:
|
def test_huggingface_repo_id(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
|
# TODO: Test subfolder download
|
||||||
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
source = HFModelSource(repo_id="stabilityai/sdxl-turbo", variant=ModelRepoVariant.Default)
|
||||||
job = mm2_installer.download_diffusers_model(source, tmp_path)
|
|
||||||
mm2_installer.wait_for_installs(timeout=5)
|
bus = mm2_installer.event_bus
|
||||||
print(job.local_path)
|
store = mm2_installer.record_store
|
||||||
assert job.status == InstallStatus.DOWNLOADS_DONE
|
assert isinstance(bus, EventServiceBase)
|
||||||
assert (tmp_path / "sdxl-turbo").exists()
|
assert store is not None
|
||||||
assert (tmp_path / "sdxl-turbo" / "model_index.json").exists()
|
|
||||||
|
job = mm2_installer.import_model(source)
|
||||||
|
job_list = mm2_installer.wait_for_installs(timeout=10)
|
||||||
|
assert len(job_list) == 1
|
||||||
|
assert job.complete
|
||||||
|
assert job.config_out
|
||||||
|
|
||||||
|
key = job.config_out.key
|
||||||
|
model_record = store.get_model(key)
|
||||||
|
assert (mm2_app_config.models_path / model_record.path).exists()
|
||||||
|
assert model_record.type == ModelType.Main
|
||||||
|
assert model_record.format == ModelFormat.Diffusers
|
||||||
|
|
||||||
|
assert hasattr(bus, "events") # the dummyeventservice has this
|
||||||
|
assert len(bus.events) >= 3
|
||||||
|
event_names = {x.event_name for x in bus.events}
|
||||||
|
assert event_names == {
|
||||||
|
"model_install_downloading",
|
||||||
|
"model_install_downloads_done",
|
||||||
|
"model_install_running",
|
||||||
|
"model_install_completed",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
def test_404_download(mm2_installer: ModelInstallServiceBase, mm2_app_config: InvokeAIAppConfig) -> None:
|
||||||
@ -327,7 +349,6 @@ def test_other_error_during_install(
|
|||||||
assert job.error == "Test error"
|
assert job.error == "Test error"
|
||||||
|
|
||||||
|
|
||||||
# TODO: Fix bug in model install causing jobs to get installed multiple times then uncomment this test
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model_params",
|
"model_params",
|
||||||
[
|
[
|
||||||
|
@ -317,4 +317,45 @@ def mm2_session(embedding_file: Path, diffusers_dir: Path) -> Session:
|
|||||||
},
|
},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
for i in ["12345", "9999", "54321"]:
|
||||||
|
content = (
|
||||||
|
b"I am a safetensors file " + bytearray(i, "utf-8") + bytearray(32_000)
|
||||||
|
) # for pause tests, must make content large
|
||||||
|
sess.mount(
|
||||||
|
f"http://www.civitai.com/models/{i}",
|
||||||
|
TestAdapter(
|
||||||
|
content,
|
||||||
|
headers={
|
||||||
|
"Content-Length": len(content),
|
||||||
|
"Content-Disposition": f'filename="mock{i}.safetensors"',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
sess.mount(
|
||||||
|
"http://www.huggingface.co/foo.txt",
|
||||||
|
TestAdapter(
|
||||||
|
content,
|
||||||
|
headers={
|
||||||
|
"Content-Length": len(content),
|
||||||
|
"Content-Disposition": 'filename="foo.safetensors"',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# here are some malformed URLs to test
|
||||||
|
# missing the content length
|
||||||
|
sess.mount(
|
||||||
|
"http://www.civitai.com/models/missing",
|
||||||
|
TestAdapter(
|
||||||
|
b"Missing content length",
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": 'filename="missing.txt"',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# not found test
|
||||||
|
sess.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
|
||||||
|
|
||||||
return sess
|
return sess
|
||||||
|
Loading…
Reference in New Issue
Block a user