refactor download queue jobs

This commit is contained in:
Lincoln Stein
2023-10-08 16:39:23 -04:00
parent a64a34b49a
commit e5b2bc8532
7 changed files with 413 additions and 319 deletions

View File

@ -5,7 +5,7 @@ from .base import ( # noqa F401
DownloadJobBase, DownloadJobBase,
DownloadJobStatus, DownloadJobStatus,
DownloadQueueBase, DownloadQueueBase,
ModelSourceMetadata,
UnknownJobIDException, UnknownJobIDException,
) )
from .queue import DownloadQueue # noqa F401 from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue # noqa F401

View File

@ -30,17 +30,6 @@ class UnknownJobIDException(Exception):
"""Raised when an invalid Job is referenced.""" """Raised when an invalid Job is referenced."""
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
name: Optional[str] = Field(description="Human-readable name of this model")
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
tags: Optional[List[str]] = Field(description="List of descriptive tags")
DownloadEventHandler = Callable[["DownloadJobBase"], None] DownloadEventHandler = Callable[["DownloadJobBase"], None]
@ -67,9 +56,6 @@ class DownloadJobBase(BaseModel):
description="if true, then preserve partial downloads when cancelled or errored", default=False description="if true, then preserve partial downloads when cancelled or errored", default=False
) )
error: Optional[Exception] = Field(default=None, description="Exception that caused an error") error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
metadata: ModelSourceMetadata = Field(
description="Metadata describing download contents", default_factory=ModelSourceMetadata
)
def add_event_handler(self, handler: DownloadEventHandler): def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list.""" """Add an event handler to the end of the handlers list."""
@ -134,7 +120,7 @@ class DownloadQueueBase(ABC):
filename: Optional[Path] = None, filename: Optional[Path] = None,
variant: Optional[str] = None, variant: Optional[str] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None, event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase: ) -> DownloadJobBase:
""" """
Create and submit a download job. Create and submit a download job.
@ -274,3 +260,17 @@ class DownloadQueueBase(ABC):
no longer recognize the job. no longer recognize the job.
""" """
pass pass
@abstractmethod
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
    """
    Choose the download routine for a job, based on the job's type.

    Returns a callable that takes the job as its only argument.
    """
    ...
@abstractmethod
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
    """
    Translate a job's source field into a downloadable URL.

    Subclasses override this method to cover the source types they support.
    """
    ...

View File

@ -0,0 +1,349 @@
import re
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
from huggingface_hub import HfApi, hf_hub_url
from pydantic import BaseModel, Field, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase
from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadQueue
# Civitai URL patterns and API endpoints, used to recognize Civitai sources
# and to retrieve metadata for them.
# NOTE(review): the dots in the patterns below are unescaped, so they match
# any character rather than a literal "." -- confirm this is acceptable.
# Regular expression for a direct Civitai download URL; group 1 is the model version id.
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
# Prefix of a Civitai model page URL. It is used both for building URLs and as
# the start of a regular expression, so it must not contain regex escapes.
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
# Regular expression for a model page URL that names a version;
# group 1 is the model id, group 2 is the model version id.
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
# Civitai API endpoint prefixes; the model id / version id is appended to them.
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expression describing a repo_id with an optional ":subfolder" suffix,
# e.g. "foo/bar" or "foo/bar:v2". Group 1 is the repo_id, group 2 the subfolder.
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
class ModelSourceMetadata(BaseModel):
    """
    Information collected on a downloadable model from its source site.

    Every field is optional and defaults to None, so an empty instance can be
    created with no arguments and filled in as metadata is discovered.
    """

    # Defaults are spelled out explicitly: the class is used as a
    # `default_factory`, which requires that it be constructible with no arguments.
    name: Optional[str] = Field(default=None, description="Human-readable name of this model")
    author: Optional[str] = Field(default=None, description="Author/creator of the model")
    description: Optional[str] = Field(default=None, description="Description of the model")
    license: Optional[str] = Field(default=None, description="Model license terms")
    thumbnail_url: Optional[AnyHttpUrl] = Field(default=None, description="URL of a thumbnail image for the model")
    tags: Optional[List[str]] = Field(default=None, description="List of descriptive tags")
class DownloadJobWithMetadata(DownloadJobRemoteSource):
    """A remote download job that also carries metadata about the model being fetched."""

    # Starts out as an empty ModelSourceMetadata; filled in from the source site.
    metadata: ModelSourceMetadata = Field(
        default_factory=ModelSourceMetadata,
        description="Metadata describing the model, derived from source",
    )
class DownloadJobRepoID(DownloadJobWithMetadata):
    """Download job whose source is a repo_id rather than a URL."""

    source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
    subfolder: Optional[str] = Field(
        default=None,
        description="Provide when the desired model is in a subfolder of the repo_id's distro",
    )
    variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
    subqueue: Optional[DownloadQueueBase] = Field(
        default=None,
        description="a subqueue used for downloading the individual files in the repo_id",
    )

    @validator("source")
    @classmethod
    def proper_repo_id(cls, v: str) -> str:  # noqa D102
        """Reject any source that does not look like a repo_id with an optional subfolder."""
        if re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v) is None:
            raise ValueError(f"{v}: invalid repo_id format")
        return v

    def cleanup(self, preserve_partial_downloads: bool = False):
        """Cancel all jobs in the subqueue, if there is one, and release it."""
        if not self.subqueue:
            return  # nothing was started, so there is nothing to clean up
        self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
        self.subqueue.release()
class ModelDownloadQueue(DownloadQueue):
"""Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai."""
def create_download_job(
    self,
    source: Union[str, Path, AnyHttpUrl],
    destdir: Path,
    start: bool = True,
    priority: int = 10,
    filename: Optional[Path] = None,
    variant: Optional[str] = None,
    access_token: Optional[str] = None,
    event_handlers: Optional[List[DownloadEventHandler]] = None,
) -> DownloadJobBase:
    """
    Create a download job, submit it to the queue and return the job.

    HTTP(S) URLs become DownloadJobWithMetadata jobs and repo_ids become
    DownloadJobRepoID jobs. Every other kind of source (including paths that
    exist on the local filesystem) is handed to the parent class.

    :param source: URL, repo_id or local path of the thing to download.
    :param destdir: Directory to download into.
    :param start: Passed through to submit_download_job().
    :param priority: Priority assigned to the job.
    :param filename: Optional file name to use within destdir.
    :param variant: Model variant (e.g. "fp16"); only used for repo_ids.
    :param access_token: Optional access token for the remote source.
    :param event_handlers: Optional list of event handlers for this job.
    :return: The job that was created and submitted.
    """
    # Build a fresh list on every call instead of sharing a mutable default.
    handlers: List[DownloadEventHandler] = list(event_handlers) if event_handlers else []
    source_str = str(source)

    cls: Optional[Type[DownloadJobBase]] = None
    kwargs: Dict[str, Optional[str]] = dict()

    if re.match(HTTP_RE, source_str):
        cls = DownloadJobWithMetadata
        kwargs.update(access_token=access_token)
    # A relative path such as "models/foo" also looks like a repo_id. If it exists
    # locally it is a path, and the parent class knows how to handle it.
    elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, source_str) and not Path(source_str).exists():
        cls = DownloadJobRepoID
        kwargs.update(
            variant=variant,
            access_token=access_token,
        )

    if cls is None:
        return super().create_download_job(
            source=source,
            destdir=destdir,
            start=start,
            priority=priority,
            filename=filename,
            variant=variant,
            access_token=access_token,
            event_handlers=handlers,
        )

    job = cls(
        source=source,
        destination=Path(destdir) / (filename or "."),
        event_handlers=handlers,
        priority=priority,
        **kwargs,
    )
    return self.submit_download_job(job, start)
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
    """
    Pick the download method that matches the type of the job.

    Repo_id jobs and metadata-bearing URL jobs are handled here; every other
    job type is left to the parent class.
    """
    # DownloadJobRepoID derives from DownloadJobWithMetadata, so it must be tested first.
    if isinstance(job, DownloadJobRepoID):
        return self._download_repoid
    if isinstance(job, DownloadJobWithMetadata):
        return self._download_with_resume
    return super().select_downloader(job)
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
    """
    Fetch metadata from certain well-known URLs and return the download URL.

    For Civitai download URLs and Civitai model pages, the Civitai API is
    queried and the results are stashed in job.metadata. Metadata fields that
    already have a value are not overwritten, except for the thumbnail taken
    from the model's version entry, which takes precedence when present.
    For any other URL the job's source is returned unchanged.

    :param job: The job to resolve; must be a DownloadJobWithMetadata.
    :return: The URL from which the model should be downloaded.
    :raises IndexError: If the Civitai model has no versions, or the requested
        version is not among them.
    """
    assert isinstance(job, DownloadJobWithMetadata)
    metadata = job.metadata
    url = job.source
    metadata_url = str(url)
    model = None
    version = None

    # A Civitai download URL: look up the version, then continue as if the
    # caller had given us the corresponding model page.
    if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url):
        version_id = match.group(1)
        resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version_id).json()
        images = resp.get("images") or []
        if images:  # a version without images simply has no thumbnail
            metadata.thumbnail_url = metadata.thumbnail_url or images[0].get("url")
        trained_words = resp.get("trainedWords")
        metadata.description = metadata.description or (
            f"Trigger terms: {(', ').join(trained_words)}" if trained_words else resp.get("description")
        )
        metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version_id}"

    # a Civitai model page with the version
    if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url):
        model = match.group(1)
        version = int(match.group(2))
    # and without
    elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
        model = match.group(1)

    # Not a Civitai URL: nothing to look up.
    if not model:
        return parse_obj_as(AnyHttpUrl, url)

    resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
    metadata.author = metadata.author or resp["creator"]["username"]
    metadata.tags = metadata.tags or resp["tags"]
    metadata.license = (
        metadata.license
        or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
    )

    model_versions = resp["modelVersions"]
    if version:
        version_data = next((x for x in model_versions if int(x["id"]) == version), None)
        if version_data is None:
            raise IndexError(f"{url}: version {version} not found for Civitai model {model}")
    else:
        if not model_versions:
            raise IndexError(f"{url}: Civitai model {model} has no versions")
        version_data = model_versions[0]  # first one

    # NOTE(review): unlike the other fields, a "url" found here replaces an
    # existing thumbnail -- confirm this precedence is intended.
    metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
    trained_words = version_data.get("trainedWords")
    metadata.description = metadata.description or (
        f"Trigger terms: {(', ').join(trained_words)}" if trained_words else version_data.get("description")
    )

    # Parse the URL so that both return paths yield the same type.
    return parse_obj_as(AnyHttpUrl, version_data["downloadUrl"])
def _download_repoid(self, job: DownloadJobBase) -> None:
    """
    Download a job that holds a huggingface repoid.

    Each file of the repo becomes its own job in a private subqueue. Events
    from those subjobs are folded back into the status, byte count and error
    of the parent job.
    """

    def subdownload_event(subjob: DownloadJobBase):
        """Propagate a subjob's progress or failure to the parent job."""
        assert isinstance(subjob, DownloadJobRemoteSource)
        assert isinstance(job, DownloadJobRemoteSource)
        status = subjob.status
        if status == DownloadJobStatus.ERROR:
            job.error = subjob.error
            job.cleanup()
            self._update_job_status(job, DownloadJobStatus.ERROR)
        elif status in (DownloadJobStatus.RUNNING, DownloadJobStatus.COMPLETED):
            # The parent stays RUNNING even when a subjob completes; it is
            # marked COMPLETED only after the whole subqueue has been joined.
            bytes_downloaded[subjob.id] = subjob.bytes
            job.bytes = sum(bytes_downloaded.values())
            self._update_job_status(job, DownloadJobStatus.RUNNING)

    subqueue = self.__class__(
        event_handlers=[subdownload_event],
        requests_session=self._requests,
        quiet=True,
    )
    assert isinstance(job, DownloadJobRepoID)
    try:
        repo_id = job.source
        variant = job.variant
        if not job.metadata:
            job.metadata = ModelSourceMetadata()
        files_to_fetch = self._get_repo_info(
            repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
        )
        # Download into a directory named after the repo.
        repo_name = Path(repo_id).name
        if job.destination.name != repo_name:
            job.destination = job.destination / repo_name
        bytes_downloaded: Dict[int, int] = dict()
        job.total_bytes = 0
        for file_url, subdir, file_name, file_size in files_to_fetch:
            job.total_bytes += file_size
            subqueue.create_download_job(
                source=file_url,
                destdir=job.destination / subdir,
                filename=file_name,
                variant=variant,
                access_token=job.access_token,
            )
    except KeyboardInterrupt:
        raise
    except Exception as excp:
        job.error = excp
        self._update_job_status(job, DownloadJobStatus.ERROR)
        self._logger.error(job.error)
    finally:
        # NOTE(review): job.subqueue is only assigned here, so a job.cleanup()
        # issued before this point finds no subqueue -- confirm this is intended.
        job.subqueue = subqueue
        job.subqueue.join()
        if job.status == DownloadJobStatus.RUNNING:
            self._update_job_status(job, DownloadJobStatus.COMPLETED)
        job.subqueue.release()  # get rid of the subqueue
def _get_repo_info(
    self,
    repo_id: str,
    metadata: ModelSourceMetadata,
    variant: Optional[str] = None,
    subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
    """
    Given a repo_id and an optional variant, return list of URLs to download to get the model.

    The metadata field will be updated with model metadata from HuggingFace;
    fields that already have a value are left alone.

    Known variants currently are:
    1. onnx
    2. openvino
    3. fp16
    4. None (usually returns fp32 model)

    :param repo_id: The HuggingFace repo_id.
    :param metadata: Metadata object to update in place.
    :param variant: The desired variant, if any.
    :param subfolder: Restrict the download to this subfolder of the repo.
    :return: A list of (url, subdirectory, file name, size) tuples. The
        subdirectory is relative to `subfolder` when one is given.
    """
    model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
    sibs = model_info.siblings
    paths = [x.rfilename for x in sibs]
    sizes = {x.rfilename: x.size for x in sibs}

    prefix = ""
    if subfolder:
        prefix = f"{subfolder}/"
        paths = [x for x in paths if x.startswith(prefix)]

    def relative_to_subfolder(path: Path) -> Path:
        """Strip the subfolder, if any, from the front of a repo path."""
        return path.relative_to(subfolder) if subfolder else path

    index_file = f"{prefix}model_index.json"
    if index_file in paths:
        url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
        resp = self._requests.get(url)
        resp.raise_for_status()  # will raise an HTTPError on non-200 status
        submodels = resp.json()
        # Keep only files that sit directly inside a folder named in model_index.json.
        # The folder name is compared with the subfolder stripped, and the repo path
        # itself is kept unchanged so that the `sizes` lookup below still works.
        paths = [x for x in paths if relative_to_subfolder(Path(x).parent).as_posix() in submodels]
        paths.insert(0, index_file)

    urls = [
        (
            hf_hub_url(repo_id, filename=x.as_posix()),
            relative_to_subfolder(x.parent),
            Path(x.name),
            sizes[x.as_posix()],
        )
        for x in self._select_variants(paths, variant)
    ]

    # cardData may be missing or empty on repos without a model card
    card_data = getattr(model_info, "cardData", None)
    if card_data:
        metadata.license = metadata.license or card_data.get("license")
    metadata.tags = metadata.tags or model_info.tags
    metadata.author = metadata.author or model_info.author
    return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
file_variant, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
file_variant = None
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
return result

View File

@ -9,7 +9,7 @@ import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from queue import PriorityQueue from queue import PriorityQueue
from typing import Dict, List, Optional, Set, Tuple, Union from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import requests import requests
from huggingface_hub import HfApi, hf_hub_url from huggingface_hub import HfApi, hf_hub_url
@ -21,14 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, Logger from invokeai.backend.util import InvokeAILogger, Logger
from ..storage import DuplicateModelException from ..storage import DuplicateModelException
from .base import ( from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
DownloadEventHandler,
DownloadJobBase,
DownloadJobStatus,
DownloadQueueBase,
ModelSourceMetadata,
UnknownJobIDException,
)
# Maximum number of bytes to download during each call to requests.iter_content() # Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000 DOWNLOAD_CHUNK_SIZE = 100000
@ -36,17 +29,8 @@ DOWNLOAD_CHUNK_SIZE = 100000
# marker that the queue is done and that thread should exit # marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/") STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
# endpoint for civitai get-model API # regular expression for picking up a URL
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls
HTTP_RE = r"^https?://" HTTP_RE = r"^https?://"
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
class DownloadJobPath(DownloadJobBase): class DownloadJobPath(DownloadJobBase):
@ -69,32 +53,6 @@ class DownloadJobURL(DownloadJobRemoteSource):
source: AnyHttpUrl = Field(description="URL to download") source: AnyHttpUrl = Field(description="URL to download")
class DownloadJobRepoID(DownloadJobRemoteSource):
"""Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
subfolder: Optional[str] = Field(
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
)
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
subqueue: Optional["DownloadQueueBase"] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None
)
@validator("source")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def cleanup(self, preserve_partial_downloads: bool = False):
"""Perform action when job is completed."""
if self.subqueue:
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
self.subqueue.release()
class DownloadQueue(DownloadQueueBase): class DownloadQueue(DownloadQueueBase):
"""Class for queued download of models.""" """Class for queued download of models."""
@ -145,7 +103,7 @@ class DownloadQueue(DownloadQueueBase):
filename: Optional[Path] = None, filename: Optional[Path] = None,
variant: Optional[str] = None, variant: Optional[str] = None,
access_token: Optional[str] = None, access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None, event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase: ) -> DownloadJobBase:
"""Create a download job and return its ID.""" """Create a download job and return its ID."""
kwargs: Dict[str, Optional[str]] = dict() kwargs: Dict[str, Optional[str]] = dict()
@ -153,12 +111,6 @@ class DownloadQueue(DownloadQueueBase):
cls = DownloadJobBase cls = DownloadJobBase
if Path(source).exists(): if Path(source).exists():
cls = DownloadJobPath cls = DownloadJobPath
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = DownloadJobRepoID
kwargs.update(
variant=variant,
access_token=access_token,
)
elif re.match(HTTP_RE, str(source)): elif re.match(HTTP_RE, str(source)):
cls = DownloadJobURL cls = DownloadJobURL
kwargs.update(access_token=access_token) kwargs.update(access_token=access_token)
@ -168,7 +120,7 @@ class DownloadQueue(DownloadQueueBase):
job = cls( job = cls(
source=source, source=source,
destination=Path(destdir) / (filename or "."), destination=Path(destdir) / (filename or "."),
event_handlers=(event_handlers or self._event_handlers), event_handlers=event_handlers,
priority=priority, priority=priority,
**kwargs, **kwargs,
) )
@ -349,89 +301,32 @@ class DownloadQueue(DownloadQueueBase):
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen) if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet: if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}") self._logger.info(f"{job.source}: Downloading to {job.destination}")
if isinstance(job, DownloadJobURL): do_download = self.select_downloader(job)
self._download_with_resume(job) do_download(job)
elif isinstance(job, DownloadJobRepoID):
self._download_repoid(job)
elif isinstance(job, DownloadJobPath):
self._download_path(job)
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
if job.status == DownloadJobStatus.CANCELLED: if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job) self._cleanup_cancelled_job(job)
self._queue.task_done() self._queue.task_done()
def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl: def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
""" """Based on the job type select the download method."""
Fetch metadata from certain well-known URLs. if isinstance(job, DownloadJobURL):
return self._download_with_resume
elif isinstance(job, DownloadJobPath):
return self._download_path
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
The metadata will be stashed in job.metadata, if found def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
Return the download URL. return job.source
"""
metadata = job.metadata
url = job.source
metadata_url = url
model = None
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None
if not model:
return parse_obj_as(AnyHttpUrl, url)
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
if version:
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
version_data = versions[0]
else:
version_data = resp["modelVersions"][0] # first one
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
if version_data.get("trainedWords")
else version_data.get("description")
)
download_url = version_data["downloadUrl"]
# return the download url
return download_url
def _download_with_resume(self, job: DownloadJobBase): def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download.""" """Do the actual download."""
dest = None dest = None
try: try:
assert isinstance(job, DownloadJobRemoteSource) assert isinstance(job, DownloadJobRemoteSource)
url = self._get_metadata_and_url(job) url = self.get_url_for_job(job)
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb" open_mode = "wb"
exist_size = 0 exist_size = 0
@ -497,7 +392,8 @@ class DownloadQueue(DownloadQueueBase):
if job.bytes - last_report_bytes >= report_delta: if job.bytes - last_report_bytes >= report_delta:
last_report_bytes = job.bytes last_report_bytes = job.bytes
self._update_job_status(job) self._update_job_status(job)
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
return
self._update_job_status(job, DownloadJobStatus.COMPLETED) self._update_job_status(job, DownloadJobStatus.COMPLETED)
except KeyboardInterrupt as excp: except KeyboardInterrupt as excp:
raise excp raise excp
@ -541,166 +437,6 @@ class DownloadQueue(DownloadQueueBase):
job.error = excp job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR) self._update_job_status(job, DownloadJobStatus.ERROR)
def _download_repoid(self, job: DownloadJobRepoID):
"""Download a job that holds a huggingface repoid."""
def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource)
if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
if subjob.status == DownloadJobStatus.ERROR:
job.error = subjob.error
job.cleanup()
self._update_job_status(job, DownloadJobStatus.ERROR)
return
if subjob.status == DownloadJobStatus.COMPLETED:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
try:
assert isinstance(job, DownloadJobRepoID)
repo_id = job.source
variant = job.variant
if not job.metadata:
job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0
for url, subdir, file, size in urls_to_download:
job.total_bytes += size
subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
variant=variant,
access_token=job.access_token,
)
except KeyboardInterrupt as excp:
raise excp
except Exception as excp:
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
job.subqueue = subqueue
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
job.subqueue.release() # get rid of the subqueue
def _get_repo_info(
self,
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
paths = [x.rfilename for x in sibs]
sizes = {x.rfilename: x.size for x in sibs}
prefix = ""
if subfolder:
prefix = f"{subfolder}/"
paths = [x for x in paths if x.startswith(prefix)]
if f"{prefix}model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json()
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, f"{prefix}model_index.json")
urls = [
(
hf_hub_url(repo_id, filename=x.as_posix()),
x.parent.relative_to(prefix) or Path("."),
Path(x.name),
sizes[x.as_posix()],
)
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
file_variant, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
file_variant = None
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
return result
def _download_path(self, job: DownloadJobBase): def _download_path(self, job: DownloadJobBase):
"""Call when the source is a Path or pathlike object.""" """Call when the source is a Path or pathlike object."""
source = Path(job.source).resolve() source = Path(job.source).resolve()

View File

@ -72,14 +72,20 @@ from .config import (
SchedulerPredictionType, SchedulerPredictionType,
SubModelType, SubModelType,
) )
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata from .download import (
from .download.queue import ( DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadJobURL,
DownloadQueueBase,
ModelDownloadQueue,
ModelSourceMetadata,
)
from .download.model_queue import (
HTTP_RE, HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobRemoteSource,
DownloadJobPath,
DownloadJobRepoID, DownloadJobRepoID,
DownloadJobURL, DownloadJobWithMetadata,
) )
from .hash import FastModelHash from .hash import FastModelHash
from .models import InvalidModelException from .models import InvalidModelException
@ -88,7 +94,7 @@ from .search import ModelSearch
from .storage import DuplicateModelException, ModelConfigStore from .storage import DuplicateModelException, ModelConfigStore
class ModelInstallJob(DownloadJobRemoteSource): class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info.""" """This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field( model_key: Optional[str] = Field(
@ -100,7 +106,7 @@ class ModelInstallJob(DownloadJobRemoteSource):
) )
class ModelInstallURLJob(DownloadJobURL, ModelInstallJob): class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
"""Job for installing URLs.""" """Job for installing URLs."""
@ -398,7 +404,7 @@ class ModelInstall(ModelInstallBase):
self._app_config = config or InvokeAIAppConfig.get_config() self._app_config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config) self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
self._store = store or ModelRecordServiceBase.get_impl(self._app_config) self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers) self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict() self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set() self._installed = set()
self._tmpdir = None self._tmpdir = None

View File

@ -8,11 +8,11 @@ import requests
from requests import HTTPError from requests import HTTPError
from requests_testadapter import TestAdapter from requests_testadapter import TestAdapter
import invokeai.backend.model_manager.download.queue as download_queue import invokeai.backend.model_manager.download.model_queue as download_queue
from invokeai.backend.model_manager.download import ( from invokeai.backend.model_manager.download import (
DownloadJobBase, DownloadJobBase,
DownloadJobStatus, DownloadJobStatus,
DownloadQueue, ModelDownloadQueue,
UnknownJobIDException, UnknownJobIDException,
) )
@ -147,7 +147,7 @@ def test_basic_queue_download():
def event_handler(job: DownloadJobBase): def event_handler(job: DownloadJobBase):
events.append(job.status) events.append(job.status)
queue = DownloadQueue( queue = ModelDownloadQueue(
requests_session=session, requests_session=session,
event_handlers=[event_handler], event_handlers=[event_handler],
) )
@ -167,7 +167,7 @@ def test_basic_queue_download():
def test_queue_priority(): def test_queue_priority():
queue = DownloadQueue( queue = ModelDownloadQueue(
requests_session=session, requests_session=session,
) )
@ -200,7 +200,7 @@ def test_repo_id_download():
if not INTERNET_AVAILABLE: if not INTERNET_AVAILABLE:
return return
repo_id = "stabilityai/stable-diffusion-2-1" repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue( queue = ModelDownloadQueue(
requests_session=session, requests_session=session,
) )
@ -224,7 +224,7 @@ def test_repo_id_download():
def test_bad_urls(): def test_bad_urls():
queue = DownloadQueue( queue = ModelDownloadQueue(
requests_session=session, requests_session=session,
) )
@ -270,9 +270,11 @@ def test_bad_urls():
def test_pause_cancel_url(): # this one is tricky because of potential race conditions def test_pause_cancel_url(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase): def event_handler(job: DownloadJobBase):
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step if job.id == 0:
print(job.status, job.bytes)
time.sleep(0.5) # slow down the thread so that we can recover the paused state
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler]) queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False) job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False) job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
@ -322,7 +324,7 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
return return
repo_id = "stabilityai/stable-diffusion-2-1" repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler]) queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2: with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False) job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)

View File

@ -55,7 +55,8 @@ def test_install(datadir: Path):
mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus) mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)
source = datadir / TEST_MODEL source = datadir / TEST_MODEL
mm_install.install_model(source=source) job = mm_install.install_model(source=source)
print(f"DEBUG: job={type(job)}")
id_map = mm_install.wait_for_installs() id_map = mm_install.wait_for_installs()
print(id_map) print(id_map)
assert source in id_map, "model did not install; id_map empty" assert source in id_map, "model did not install; id_map empty"