refactor download queue jobs

Lincoln Stein
2023-10-08 16:39:23 -04:00
parent a64a34b49a
commit e5b2bc8532
7 changed files with 413 additions and 319 deletions

View File

@@ -5,7 +5,7 @@ from .base import ( # noqa F401
DownloadJobBase,
DownloadJobStatus,
DownloadQueueBase,
ModelSourceMetadata,
UnknownJobIDException,
)
from .queue import DownloadQueue # noqa F401
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue # noqa F401
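With this split, the package re-exports both queue flavors. A sketch of the resulting import surface (the package path is taken from the tests further below):

```python
# What callers can import from the download package after this commit.
from invokeai.backend.model_manager.download import (
    DownloadQueue,       # plain queue: local paths and direct URLs
    DownloadJobPath,
    DownloadJobURL,
    ModelDownloadQueue,  # metadata-aware queue: HuggingFace repo_ids, Civitai URLs
    ModelSourceMetadata,
)
```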

View File

@@ -30,17 +30,6 @@ class UnknownJobIDException(Exception):
"""Raised when an invalid Job is referenced."""
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
name: Optional[str] = Field(description="Human-readable name of this model")
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
tags: Optional[List[str]] = Field(description="List of descriptive tags")
DownloadEventHandler = Callable[["DownloadJobBase"], None]
@@ -67,9 +56,6 @@ class DownloadJobBase(BaseModel):
description="if true, then preserve partial downloads when cancelled or errored", default=False
)
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
metadata: ModelSourceMetadata = Field(
description="Metadata describing download contents", default_factory=ModelSourceMetadata
)
def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list."""
@@ -134,7 +120,7 @@ class DownloadQueueBase(ABC):
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""
Create and submit a download job.
@@ -274,3 +260,17 @@
no longer recognize the job.
"""
pass
@abstractmethod
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
pass
@abstractmethod
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Given a job, translate its source field into a downloadable URL.
Intended to be subclassed to cover various source types.
"""
pass
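These two hooks formalize a template-method pattern: the worker thread asks select_downloader for the callable that handles a job, and the downloader calls get_url_for_job to turn the job's source into a fetchable URL. A minimal sketch of a conforming subclass; the MirrorQueue name and mirror host are hypothetical illustrations, not part of this commit:

```python
from pydantic import parse_obj_as
from pydantic.networks import AnyHttpUrl

from invokeai.backend.model_manager.download import DownloadJobBase, DownloadQueue


class MirrorQueue(DownloadQueue):
    """Hypothetical subclass: redirect every URL job through a fixed mirror."""

    def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
        # _download_with_resume resolves the final URL through this hook,
        # so overriding it is enough to change where the bytes come from.
        return parse_obj_as(AnyHttpUrl, f"https://mirror.example.com/{job.source}")
```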

View File

@@ -0,0 +1,349 @@
import re
from pathlib import Path
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
from huggingface_hub import HfApi, hf_hub_url
from pydantic import BaseModel, Field, parse_obj_as, validator
from pydantic.networks import AnyHttpUrl
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase
from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadQueue
# regular expressions used to dispatch appropriate downloaders and metadata retrievers
# endpoint for civitai get-model API
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
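As a quick reference, what the repo_id pattern does and does not accept (examples illustrative):

```python
import re

REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"

m = re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, "stabilityai/stable-diffusion-2-1")
assert m and m.group(1) == "stabilityai/stable-diffusion-2-1" and m.group(2) is None

m = re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, "foo/bar:v2")  # repo_id plus subfolder
assert m and m.group(1) == "foo/bar" and m.group(2) == "v2"

assert re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, "https://example.com/x") is None
```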
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
name: Optional[str] = Field(description="Human-readable name of this model")
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
tags: Optional[List[str]] = Field(description="List of descriptive tags")
class DownloadJobWithMetadata(DownloadJobRemoteSource):
"""A remote download that has metadata associated with it."""
metadata: ModelSourceMetadata = Field(
description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata
)
class DownloadJobRepoID(DownloadJobWithMetadata):
"""Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/bar:v2)")
subfolder: Optional[str] = Field(
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
)
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
subqueue: Optional[DownloadQueueBase] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None
)
@validator("source")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def cleanup(self, preserve_partial_downloads: bool = False):
"""Perform action when job is completed."""
if self.subqueue:
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
self.subqueue.release()
class ModelDownloadQueue(DownloadQueue):
"""Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai."""
def create_download_job(
self,
source: Union[str, Path, AnyHttpUrl],
destdir: Path,
start: bool = True,
priority: int = 10,
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""Create a download job and return its ID."""
cls: Optional[Type[DownloadJobBase]] = None
kwargs: Dict[str, Optional[str]] = dict()
if re.match(HTTP_RE, str(source)):
cls = DownloadJobWithMetadata
kwargs.update(access_token=access_token)
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = DownloadJobRepoID
kwargs.update(
variant=variant,
access_token=access_token,
)
if cls:
job = cls(
source=source,
destination=Path(destdir) / (filename or "."),
event_handlers=event_handlers,
priority=priority,
**kwargs,
)
return self.submit_download_job(job, start)
else:
return super().create_download_job(
source=source,
destdir=destdir,
start=start,
priority=priority,
filename=filename,
variant=variant,
access_token=access_token,
event_handlers=event_handlers,
)
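For orientation, a hedged usage sketch of the new queue's source dispatch (destination paths and ids are illustrative, and default constructor arguments are assumed):

```python
from pathlib import Path

queue = ModelDownloadQueue()

# HTTP(S) sources become DownloadJobWithMetadata, so Civitai metadata is collected.
url_job = queue.create_download_job(
    source="https://civitai.com/api/download/models/12345",
    destdir=Path("/tmp/models"),
)

# Bare repo_ids (optionally "owner/repo:subfolder") become DownloadJobRepoID.
repo_job = queue.create_download_job(
    source="stabilityai/stable-diffusion-2-1",
    destdir=Path("/tmp/models"),
    variant="fp16",
)

# Anything else (e.g. a local path) falls through to DownloadQueue's handling.
```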
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
if isinstance(job, DownloadJobRepoID):
return self._download_repoid
elif isinstance(job, DownloadJobWithMetadata):
return self._download_with_resume
else:
return super().select_downloader(job)
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.
The metadata will be stashed in job.metadata, if found.
Return the download URL.
"""
assert isinstance(job, DownloadJobWithMetadata)
metadata = job.metadata
url = job.source
metadata_url = url
model = None
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None
if not model:
return parse_obj_as(AnyHttpUrl, url)
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
if version:
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
version_data = versions[0]
else:
version_data = resp["modelVersions"][0] # first one
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
if version_data.get("trainedWords")
else version_data.get("description")
)
download_url = version_data["downloadUrl"]
# return the download url
return download_url
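In short, get_url_for_job recognizes three Civitai URL shapes and walks each down to a direct file URL (ids illustrative):

```python
# 1. Download URL:          https://civitai.com/api/download/models/1234
#    -> GET CIVITAI_VERSIONS_ENDPOINT + "1234" for version metadata,
#       then rewritten as shape 2 below.
# 2. Model page + version:  https://civitai.com/models/5678?modelVersionId=1234
#    -> GET CIVITAI_MODELS_ENDPOINT + "5678", select the matching entry
#       in resp["modelVersions"].
# 3. Bare model page:       https://civitai.com/models/5678
#    -> GET CIVITAI_MODELS_ENDPOINT + "5678", fall back to modelVersions[0].
# In every case the returned URL is version_data["downloadUrl"]; a source that
# matches none of these patterns is returned unchanged.
```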
def _download_repoid(self, job: DownloadJobBase) -> None:
"""Download a job that holds a huggingface repoid."""
def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource)
assert isinstance(job, DownloadJobRemoteSource)
if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
if subjob.status == DownloadJobStatus.ERROR:
job.error = subjob.error
job.cleanup()
self._update_job_status(job, DownloadJobStatus.ERROR)
return
if subjob.status == DownloadJobStatus.COMPLETED:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
assert isinstance(job, DownloadJobRepoID)
try:
repo_id = job.source
variant = job.variant
if not job.metadata:
job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0
for url, subdir, file, size in urls_to_download:
job.total_bytes += size
subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
variant=variant,
access_token=job.access_token,
)
except KeyboardInterrupt as excp:
raise excp
except Exception as excp:
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
job.subqueue = subqueue
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
job.subqueue.release() # get rid of the subqueue
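The repo_id downloader fans out through a private subqueue: every file in the repo becomes a child job, and the parent's progress is the running sum of child byte counts. Schematically:

```python
# parent: DownloadJobRepoID(source="owner/repo")
#   └── subqueue = self.__class__(event_handlers=[subdownload_event], quiet=True)
#         ├── child job: model_index.json                 bytes_downloaded[id] = bytes
#         ├── child job: unet/diffusion_pytorch_model...  bytes_downloaded[id] = bytes
#         └── ...
# subdownload_event() aggregates job.bytes = sum(bytes_downloaded.values()) and
# keeps the parent RUNNING; after subqueue.join() the parent is marked COMPLETED
# (or ERROR if a child errored) and the subqueue is released.
```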
def _get_repo_info(
self,
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
paths = [x.rfilename for x in sibs]
sizes = {x.rfilename: x.size for x in sibs}
prefix = ""
if subfolder:
prefix = f"{subfolder}/"
paths = [x for x in paths if x.startswith(prefix)]
if f"{prefix}model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json()
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, f"{prefix}model_index.json")
urls = [
(
hf_hub_url(repo_id, filename=x.as_posix()),
x.parent.relative_to(prefix) or Path("."),
Path(x.name),
sizes[x.as_posix()],
)
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
file_variant, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
file_variant = None
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
return result
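To make the selection rules concrete, a worked example under the logic above (file names are illustrative, and default constructor arguments are assumed):

```python
from pathlib import Path

queue = ModelDownloadQueue()

paths = [
    "model_index.json",
    "unet/diffusion_pytorch_model.bin",
    "unet/diffusion_pytorch_model.fp16.bin",
    "unet/diffusion_pytorch_model.fp16.safetensors",
    "text_encoder/model.onnx",
]

selected = queue._select_variants(paths, variant="fp16")
# .json/.txt files are always kept; among the unet weights the fp16 variant is
# required and .safetensors beats .bin; the .onnx file needs variant="onnx".
assert selected == {
    Path("model_index.json"),
    Path("unet/diffusion_pytorch_model.fp16.safetensors"),
}
```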

View File

@@ -9,7 +9,7 @@ import time
import traceback
from pathlib import Path
from queue import PriorityQueue
from typing import Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
import requests
from huggingface_hub import HfApi, hf_hub_url
@@ -21,14 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import InvokeAILogger, Logger
from ..storage import DuplicateModelException
from .base import (
DownloadEventHandler,
DownloadJobBase,
DownloadJobStatus,
DownloadQueueBase,
ModelSourceMetadata,
UnknownJobIDException,
)
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
# Maximum number of bytes to download during each call to requests.iter_content()
DOWNLOAD_CHUNK_SIZE = 100000
@@ -36,17 +29,8 @@ DOWNLOAD_CHUNK_SIZE = 100000
# marker that the queue is done and that thread should exit
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
# endpoint for civitai get-model API
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
# Regular expressions to describe repo_ids and http urls
# regular expression for picking up a URL
HTTP_RE = r"^https?://"
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
class DownloadJobPath(DownloadJobBase):
@@ -69,32 +53,6 @@ class DownloadJobURL(DownloadJobRemoteSource):
source: AnyHttpUrl = Field(description="URL to download")
class DownloadJobRepoID(DownloadJobRemoteSource):
"""Download repo ids."""
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/bar:v2)")
subfolder: Optional[str] = Field(
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
)
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
subqueue: Optional["DownloadQueueBase"] = Field(
description="a subqueue used for downloading the individual files in the repo_id", default=None
)
@validator("source")
@classmethod
def proper_repo_id(cls, v: str) -> str: # noqa D102
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
raise ValueError(f"{v}: invalid repo_id format")
return v
def cleanup(self, preserve_partial_downloads: bool = False):
"""Perform action when job is completed."""
if self.subqueue:
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
self.subqueue.release()
class DownloadQueue(DownloadQueueBase):
"""Class for queued download of models."""
@@ -145,7 +103,7 @@ class DownloadQueue(DownloadQueueBase):
filename: Optional[Path] = None,
variant: Optional[str] = None,
access_token: Optional[str] = None,
event_handlers: Optional[List[DownloadEventHandler]] = None,
event_handlers: List[DownloadEventHandler] = [],
) -> DownloadJobBase:
"""Create a download job and return its ID."""
kwargs: Dict[str, Optional[str]] = dict()
@@ -153,12 +111,6 @@ class DownloadQueue(DownloadQueueBase):
cls = DownloadJobBase
if Path(source).exists():
cls = DownloadJobPath
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
cls = DownloadJobRepoID
kwargs.update(
variant=variant,
access_token=access_token,
)
elif re.match(HTTP_RE, str(source)):
cls = DownloadJobURL
kwargs.update(access_token=access_token)
@@ -168,7 +120,7 @@ class DownloadQueue(DownloadQueueBase):
job = cls(
source=source,
destination=Path(destdir) / (filename or "."),
event_handlers=(event_handlers or self._event_handlers),
event_handlers=event_handlers,
priority=priority,
**kwargs,
)
@@ -349,89 +301,32 @@ class DownloadQueue(DownloadQueueBase):
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
if not self._quiet:
self._logger.info(f"{job.source}: Downloading to {job.destination}")
if isinstance(job, DownloadJobURL):
self._download_with_resume(job)
elif isinstance(job, DownloadJobRepoID):
self._download_repoid(job)
elif isinstance(job, DownloadJobPath):
self._download_path(job)
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
do_download = self.select_downloader(job)
do_download(job)
if job.status == DownloadJobStatus.CANCELLED:
self._cleanup_cancelled_job(job)
self._queue.task_done()
def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl:
"""
Fetch metadata from certain well-known URLs.
The metadata will be stashed in job.metadata, if found.
Return the download URL.
"""
metadata = job.metadata
url = job.source
metadata_url = url
model = None
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
model = match.group(1)
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
model = match.group(1)
version = None
if not model:
return parse_obj_as(AnyHttpUrl, url)
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
metadata.license = (
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
if version:
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
version_data = versions[0]
else:
version_data = resp["modelVersions"][0]  # first one
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
if version_data.get("trainedWords")
else version_data.get("description")
)
download_url = version_data["downloadUrl"]
# return the download url
return download_url
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
"""Based on the job type select the download method."""
if isinstance(job, DownloadJobURL):
return self._download_with_resume
elif isinstance(job, DownloadJobPath):
return self._download_path
else:
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
return job.source
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
dest = None
try:
assert isinstance(job, DownloadJobRemoteSource)
url = self._get_metadata_and_url(job)
url = self.get_url_for_job(job)
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
exist_size = 0
@@ -497,7 +392,8 @@ class DownloadQueue(DownloadQueueBase):
if job.bytes - last_report_bytes >= report_delta:
last_report_bytes = job.bytes
self._update_job_status(job)
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
return
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except KeyboardInterrupt as excp:
raise excp
@@ -541,166 +437,6 @@ class DownloadQueue(DownloadQueueBase):
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
def _download_repoid(self, job: DownloadJobRepoID):
"""Download a job that holds a huggingface repoid."""
def subdownload_event(subjob: DownloadJobBase):
assert isinstance(subjob, DownloadJobRemoteSource)
if subjob.status == DownloadJobStatus.RUNNING:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
if subjob.status == DownloadJobStatus.ERROR:
job.error = subjob.error
job.cleanup()
self._update_job_status(job, DownloadJobStatus.ERROR)
return
if subjob.status == DownloadJobStatus.COMPLETED:
bytes_downloaded[subjob.id] = subjob.bytes
job.bytes = sum(bytes_downloaded.values())
self._update_job_status(job, DownloadJobStatus.RUNNING)
return
subqueue = self.__class__(
event_handlers=[subdownload_event],
requests_session=self._requests,
quiet=True,
)
try:
assert isinstance(job, DownloadJobRepoID)
repo_id = job.source
variant = job.variant
if not job.metadata:
job.metadata = ModelSourceMetadata()
urls_to_download = self._get_repo_info(
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
)
if job.destination.name != Path(repo_id).name:
job.destination = job.destination / Path(repo_id).name
bytes_downloaded: Dict[int, int] = dict()
job.total_bytes = 0
for url, subdir, file, size in urls_to_download:
job.total_bytes += size
subqueue.create_download_job(
source=url,
destdir=job.destination / subdir,
filename=file,
variant=variant,
access_token=job.access_token,
)
except KeyboardInterrupt as excp:
raise excp
except Exception as excp:
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
job.subqueue = subqueue
job.subqueue.join()
if job.status == DownloadJobStatus.RUNNING:
self._update_job_status(job, DownloadJobStatus.COMPLETED)
job.subqueue.release() # get rid of the subqueue
def _get_repo_info(
self,
repo_id: str,
metadata: ModelSourceMetadata,
variant: Optional[str] = None,
subfolder: Optional[str] = None,
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
"""
Given a repo_id and an optional variant, return list of URLs to download to get the model.
The metadata field will be updated with model metadata from HuggingFace.
Known variants currently are:
1. onnx
2. openvino
3. fp16
4. None (usually returns fp32 model)
"""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
paths = [x.rfilename for x in sibs]
sizes = {x.rfilename: x.size for x in sibs}
prefix = ""
if subfolder:
prefix = f"{subfolder}/"
paths = [x for x in paths if x.startswith(prefix)]
if f"{prefix}model_index.json" in paths:
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
resp = self._requests.get(url)
resp.raise_for_status() # will raise an HTTPError on non-200 status
submodels = resp.json()
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
paths.insert(0, f"{prefix}model_index.json")
urls = [
(
hf_hub_url(repo_id, filename=x.as_posix()),
x.parent.relative_to(prefix) or Path("."),
Path(x.name),
sizes[x.as_posix()],
)
for x in self._select_variants(paths, variant)
]
if hasattr(model_info, "cardData"):
metadata.license = metadata.license or model_info.cardData.get("license")
metadata.tags = metadata.tags or model_info.tags
metadata.author = metadata.author or model_info.author
return urls
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
result = set()
basenames: Dict[Path, Path] = dict()
for p in paths:
path = Path(p)
if path.suffix == ".onnx":
if variant == "onnx":
result.add(path)
elif path.name.startswith("openvino_model"):
if variant == "openvino":
result.add(path)
elif path.suffix in [".json", ".txt"]:
result.add(path)
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
parent = path.parent
suffixes = path.suffixes
if len(suffixes) == 2:
file_variant, suffix = suffixes
basename = parent / Path(path.stem).stem
else:
file_variant = None
suffix = suffixes[0]
basename = parent / path.stem
if previous := basenames.get(basename):
if previous.suffix != ".safetensors" and suffix == ".safetensors":
basenames[basename] = path
if file_variant == f".{variant}":
basenames[basename] = path
elif not variant and not file_variant:
basenames[basename] = path
else:
basenames[basename] = path
else:
continue
for v in basenames.values():
result.add(v)
return result
def _download_path(self, job: DownloadJobBase):
"""Call when the source is a Path or pathlike object."""
source = Path(job.source).resolve()

View File

@@ -72,14 +72,20 @@ from .config import (
SchedulerPredictionType,
SubModelType,
)
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
from .download.queue import (
from .download import (
DownloadEventHandler,
DownloadJobBase,
DownloadJobPath,
DownloadJobURL,
DownloadQueueBase,
ModelDownloadQueue,
ModelSourceMetadata,
)
from .download.model_queue import (
HTTP_RE,
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
DownloadJobRemoteSource,
DownloadJobPath,
DownloadJobRepoID,
DownloadJobURL,
DownloadJobWithMetadata,
)
from .hash import FastModelHash
from .models import InvalidModelException
@@ -88,7 +94,7 @@ from .search import ModelSearch
from .storage import DuplicateModelException, ModelConfigStore
class ModelInstallJob(DownloadJobRemoteSource):
class ModelInstallJob(DownloadJobBase):
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
model_key: Optional[str] = Field(
@@ -100,7 +106,7 @@ class ModelInstallJob(DownloadJobRemoteSource):
)
class ModelInstallURLJob(DownloadJobURL, ModelInstallJob):
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
"""Job for installing URLs."""
@@ -398,7 +404,7 @@ class ModelInstall(ModelInstallBase):
self._app_config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
self._installed = set()
self._tmpdir = None
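Since the installer now builds a ModelDownloadQueue by default, URL and repo_id installs pick up source metadata automatically. A hedged sketch of the flow (the handler name and source are illustrative; default config is assumed):

```python
def on_progress(job: DownloadJobBase) -> None:
    # Fires on every state change of the underlying download job.
    print(job.id, job.status)

installer = ModelInstall(event_handlers=[on_progress])
installer.install_model(source="stabilityai/stable-diffusion-2-1")
id_map = installer.wait_for_installs()  # maps each source to its model key
```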

View File

@@ -8,11 +8,11 @@ import requests
from requests import HTTPError
from requests_testadapter import TestAdapter
import invokeai.backend.model_manager.download.queue as download_queue
import invokeai.backend.model_manager.download.model_queue as download_queue
from invokeai.backend.model_manager.download import (
DownloadJobBase,
DownloadJobStatus,
DownloadQueue,
ModelDownloadQueue,
UnknownJobIDException,
)
@@ -147,7 +147,7 @@ def test_basic_queue_download():
def event_handler(job: DownloadJobBase):
events.append(job.status)
queue = DownloadQueue(
queue = ModelDownloadQueue(
requests_session=session,
event_handlers=[event_handler],
)
@@ -167,7 +167,7 @@ def test_basic_queue_download():
def test_queue_priority():
queue = DownloadQueue(
queue = ModelDownloadQueue(
requests_session=session,
)
@@ -200,7 +200,7 @@ def test_repo_id_download():
if not INTERNET_AVAILABLE:
return
repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue(
queue = ModelDownloadQueue(
requests_session=session,
)
@@ -224,7 +224,7 @@ def test_repo_id_download():
def test_bad_urls():
queue = DownloadQueue(
queue = ModelDownloadQueue(
requests_session=session,
)
@@ -270,9 +270,11 @@ def test_bad_urls():
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
def event_handler(job: DownloadJobBase):
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
if job.id == 0:
print(job.status, job.bytes)
time.sleep(0.5) # slow down the thread so that we can recover the paused state
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir:
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
@@ -322,7 +324,7 @@ def test_pause_cancel_url(): # this one is tricky because of potential race conditions
return
repo_id = "stabilityai/stable-diffusion-2-1"
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)

View File

@@ -55,7 +55,8 @@ def test_install(datadir: Path):
mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)
source = datadir / TEST_MODEL
mm_install.install_model(source=source)
job = mm_install.install_model(source=source)
print(f"DEBUG: job={type(job)}")
id_map = mm_install.wait_for_installs()
print(id_map)
assert source in id_map, "model did not install; id_map empty"