mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor download queue jobs
This commit is contained in:
@ -5,7 +5,7 @@ from .base import ( # noqa F401
|
|||||||
DownloadJobBase,
|
DownloadJobBase,
|
||||||
DownloadJobStatus,
|
DownloadJobStatus,
|
||||||
DownloadQueueBase,
|
DownloadQueueBase,
|
||||||
ModelSourceMetadata,
|
|
||||||
UnknownJobIDException,
|
UnknownJobIDException,
|
||||||
)
|
)
|
||||||
from .queue import DownloadQueue # noqa F401
|
from .model_queue import ModelDownloadQueue, ModelSourceMetadata # noqa F401
|
||||||
|
from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue # noqa F401
|
||||||
|
@ -30,17 +30,6 @@ class UnknownJobIDException(Exception):
|
|||||||
"""Raised when an invalid Job is referenced."""
|
"""Raised when an invalid Job is referenced."""
|
||||||
|
|
||||||
|
|
||||||
class ModelSourceMetadata(BaseModel):
|
|
||||||
"""Information collected on a downloadable model from its source site."""
|
|
||||||
|
|
||||||
name: Optional[str] = Field(description="Human-readable name of this model")
|
|
||||||
author: Optional[str] = Field(description="Author/creator of the model")
|
|
||||||
description: Optional[str] = Field(description="Description of the model")
|
|
||||||
license: Optional[str] = Field(description="Model license terms")
|
|
||||||
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
|
|
||||||
tags: Optional[List[str]] = Field(description="List of descriptive tags")
|
|
||||||
|
|
||||||
|
|
||||||
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
DownloadEventHandler = Callable[["DownloadJobBase"], None]
|
||||||
|
|
||||||
|
|
||||||
@ -67,9 +56,6 @@ class DownloadJobBase(BaseModel):
|
|||||||
description="if true, then preserve partial downloads when cancelled or errored", default=False
|
description="if true, then preserve partial downloads when cancelled or errored", default=False
|
||||||
)
|
)
|
||||||
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
|
||||||
metadata: ModelSourceMetadata = Field(
|
|
||||||
description="Metadata describing download contents", default_factory=ModelSourceMetadata
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_event_handler(self, handler: DownloadEventHandler):
|
def add_event_handler(self, handler: DownloadEventHandler):
|
||||||
"""Add an event handler to the end of the handlers list."""
|
"""Add an event handler to the end of the handlers list."""
|
||||||
@ -134,7 +120,7 @@ class DownloadQueueBase(ABC):
|
|||||||
filename: Optional[Path] = None,
|
filename: Optional[Path] = None,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
event_handlers: List[DownloadEventHandler] = [],
|
||||||
) -> DownloadJobBase:
|
) -> DownloadJobBase:
|
||||||
"""
|
"""
|
||||||
Create and submit a download job.
|
Create and submit a download job.
|
||||||
@ -274,3 +260,17 @@ class DownloadQueueBase(ABC):
|
|||||||
no longer recognize the job.
|
no longer recognize the job.
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||||
|
"""Based on the job type select the download method."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||||
|
"""
|
||||||
|
Given a job, translate its source field into a downloadable URL.
|
||||||
|
|
||||||
|
Intended to be subclassed to cover various source types.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
349
invokeai/backend/model_manager/download/model_queue.py
Normal file
349
invokeai/backend/model_manager/download/model_queue.py
Normal file
@ -0,0 +1,349 @@
|
|||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
|
||||||
|
|
||||||
|
from huggingface_hub import HfApi, hf_hub_url
|
||||||
|
from pydantic import BaseModel, Field, parse_obj_as, validator
|
||||||
|
from pydantic.networks import AnyHttpUrl
|
||||||
|
|
||||||
|
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase
|
||||||
|
from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadQueue
|
||||||
|
|
||||||
|
# regular expressions used to dispatch appropriate downloaders and metadata retrievers
|
||||||
|
# endpoint for civitai get-model API
|
||||||
|
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
|
||||||
|
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
|
||||||
|
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||||
|
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||||
|
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||||
|
|
||||||
|
# Regular expressions to describe repo_ids and http urls
|
||||||
|
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSourceMetadata(BaseModel):
|
||||||
|
"""Information collected on a downloadable model from its source site."""
|
||||||
|
|
||||||
|
name: Optional[str] = Field(description="Human-readable name of this model")
|
||||||
|
author: Optional[str] = Field(description="Author/creator of the model")
|
||||||
|
description: Optional[str] = Field(description="Description of the model")
|
||||||
|
license: Optional[str] = Field(description="Model license terms")
|
||||||
|
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
|
||||||
|
tags: Optional[List[str]] = Field(description="List of descriptive tags")
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadJobWithMetadata(DownloadJobRemoteSource):
|
||||||
|
"""A remote download that has metadata associated with it."""
|
||||||
|
|
||||||
|
metadata: ModelSourceMetadata = Field(
|
||||||
|
description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DownloadJobRepoID(DownloadJobWithMetadata):
|
||||||
|
"""Download repo ids."""
|
||||||
|
|
||||||
|
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
||||||
|
subfolder: Optional[str] = Field(
|
||||||
|
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
|
||||||
|
)
|
||||||
|
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
||||||
|
subqueue: Optional[DownloadQueueBase] = Field(
|
||||||
|
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
@validator("source")
|
||||||
|
@classmethod
|
||||||
|
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
||||||
|
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
|
||||||
|
raise ValueError(f"{v}: invalid repo_id format")
|
||||||
|
return v
|
||||||
|
|
||||||
|
def cleanup(self, preserve_partial_downloads: bool = False):
|
||||||
|
"""Perform action when job is completed."""
|
||||||
|
if self.subqueue:
|
||||||
|
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
|
||||||
|
self.subqueue.release()
|
||||||
|
|
||||||
|
|
||||||
|
class ModelDownloadQueue(DownloadQueue):
|
||||||
|
"""Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai."""
|
||||||
|
|
||||||
|
def create_download_job(
|
||||||
|
self,
|
||||||
|
source: Union[str, Path, AnyHttpUrl],
|
||||||
|
destdir: Path,
|
||||||
|
start: bool = True,
|
||||||
|
priority: int = 10,
|
||||||
|
filename: Optional[Path] = None,
|
||||||
|
variant: Optional[str] = None,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
event_handlers: List[DownloadEventHandler] = [],
|
||||||
|
) -> DownloadJobBase:
|
||||||
|
"""Create a download job and return its ID."""
|
||||||
|
cls: Optional[Type[DownloadJobBase]] = None
|
||||||
|
kwargs: Dict[str, Optional[str]] = dict()
|
||||||
|
|
||||||
|
if re.match(HTTP_RE, str(source)):
|
||||||
|
cls = DownloadJobWithMetadata
|
||||||
|
kwargs.update(access_token=access_token)
|
||||||
|
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
||||||
|
cls = DownloadJobRepoID
|
||||||
|
kwargs.update(
|
||||||
|
variant=variant,
|
||||||
|
access_token=access_token,
|
||||||
|
)
|
||||||
|
if cls:
|
||||||
|
job = cls(
|
||||||
|
source=source,
|
||||||
|
destination=Path(destdir) / (filename or "."),
|
||||||
|
event_handlers=event_handlers,
|
||||||
|
priority=priority,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
return self.submit_download_job(job, start)
|
||||||
|
else:
|
||||||
|
return super().create_download_job(
|
||||||
|
source=source,
|
||||||
|
destdir=destdir,
|
||||||
|
start=start,
|
||||||
|
priority=priority,
|
||||||
|
filename=filename,
|
||||||
|
variant=variant,
|
||||||
|
access_token=access_token,
|
||||||
|
event_handlers=event_handlers,
|
||||||
|
)
|
||||||
|
|
||||||
|
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||||
|
"""Based on the job type select the download method."""
|
||||||
|
if isinstance(job, DownloadJobRepoID):
|
||||||
|
return self._download_repoid
|
||||||
|
elif isinstance(job, DownloadJobWithMetadata):
|
||||||
|
return self._download_with_resume
|
||||||
|
else:
|
||||||
|
return super().select_downloader(job)
|
||||||
|
|
||||||
|
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||||
|
"""
|
||||||
|
Fetch metadata from certain well-known URLs.
|
||||||
|
|
||||||
|
The metadata will be stashed in job.metadata, if found
|
||||||
|
Return the download URL.
|
||||||
|
"""
|
||||||
|
assert isinstance(job, DownloadJobWithMetadata)
|
||||||
|
metadata = job.metadata
|
||||||
|
url = job.source
|
||||||
|
metadata_url = url
|
||||||
|
model = None
|
||||||
|
|
||||||
|
# a Civitai download URL
|
||||||
|
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
|
||||||
|
version = match.group(1)
|
||||||
|
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
||||||
|
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
||||||
|
metadata.description = metadata.description or (
|
||||||
|
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
|
||||||
|
)
|
||||||
|
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
|
||||||
|
|
||||||
|
# a Civitai model page with the version
|
||||||
|
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
|
||||||
|
model = match.group(1)
|
||||||
|
version = int(match.group(2))
|
||||||
|
# and without
|
||||||
|
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
|
||||||
|
model = match.group(1)
|
||||||
|
version = None
|
||||||
|
|
||||||
|
if not model:
|
||||||
|
return parse_obj_as(AnyHttpUrl, url)
|
||||||
|
|
||||||
|
if model:
|
||||||
|
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
|
||||||
|
|
||||||
|
metadata.author = metadata.author or resp["creator"]["username"]
|
||||||
|
metadata.tags = metadata.tags or resp["tags"]
|
||||||
|
metadata.license = (
|
||||||
|
metadata.license
|
||||||
|
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if version:
|
||||||
|
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
|
||||||
|
version_data = versions[0]
|
||||||
|
else:
|
||||||
|
version_data = resp["modelVersions"][0] # first one
|
||||||
|
|
||||||
|
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
|
||||||
|
metadata.description = metadata.description or (
|
||||||
|
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
|
||||||
|
if version_data.get("trainedWords")
|
||||||
|
else version_data.get("description")
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = version_data["downloadUrl"]
|
||||||
|
|
||||||
|
# return the download url
|
||||||
|
return download_url
|
||||||
|
|
||||||
|
def _download_repoid(self, job: DownloadJobBase) -> None:
|
||||||
|
"""Download a job that holds a huggingface repoid."""
|
||||||
|
|
||||||
|
def subdownload_event(subjob: DownloadJobBase):
|
||||||
|
assert isinstance(subjob, DownloadJobRemoteSource)
|
||||||
|
assert isinstance(job, DownloadJobRemoteSource)
|
||||||
|
if subjob.status == DownloadJobStatus.RUNNING:
|
||||||
|
bytes_downloaded[subjob.id] = subjob.bytes
|
||||||
|
job.bytes = sum(bytes_downloaded.values())
|
||||||
|
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||||
|
return
|
||||||
|
|
||||||
|
if subjob.status == DownloadJobStatus.ERROR:
|
||||||
|
job.error = subjob.error
|
||||||
|
job.cleanup()
|
||||||
|
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||||
|
return
|
||||||
|
|
||||||
|
if subjob.status == DownloadJobStatus.COMPLETED:
|
||||||
|
bytes_downloaded[subjob.id] = subjob.bytes
|
||||||
|
job.bytes = sum(bytes_downloaded.values())
|
||||||
|
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
||||||
|
return
|
||||||
|
|
||||||
|
subqueue = self.__class__(
|
||||||
|
event_handlers=[subdownload_event],
|
||||||
|
requests_session=self._requests,
|
||||||
|
quiet=True,
|
||||||
|
)
|
||||||
|
assert isinstance(job, DownloadJobRepoID)
|
||||||
|
try:
|
||||||
|
repo_id = job.source
|
||||||
|
variant = job.variant
|
||||||
|
if not job.metadata:
|
||||||
|
job.metadata = ModelSourceMetadata()
|
||||||
|
urls_to_download = self._get_repo_info(
|
||||||
|
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
|
||||||
|
)
|
||||||
|
if job.destination.name != Path(repo_id).name:
|
||||||
|
job.destination = job.destination / Path(repo_id).name
|
||||||
|
bytes_downloaded: Dict[int, int] = dict()
|
||||||
|
job.total_bytes = 0
|
||||||
|
|
||||||
|
for url, subdir, file, size in urls_to_download:
|
||||||
|
job.total_bytes += size
|
||||||
|
subqueue.create_download_job(
|
||||||
|
source=url,
|
||||||
|
destdir=job.destination / subdir,
|
||||||
|
filename=file,
|
||||||
|
variant=variant,
|
||||||
|
access_token=job.access_token,
|
||||||
|
)
|
||||||
|
except KeyboardInterrupt as excp:
|
||||||
|
raise excp
|
||||||
|
except Exception as excp:
|
||||||
|
job.error = excp
|
||||||
|
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||||
|
self._logger.error(job.error)
|
||||||
|
finally:
|
||||||
|
job.subqueue = subqueue
|
||||||
|
job.subqueue.join()
|
||||||
|
if job.status == DownloadJobStatus.RUNNING:
|
||||||
|
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||||
|
job.subqueue.release() # get rid of the subqueue
|
||||||
|
|
||||||
|
def _get_repo_info(
|
||||||
|
self,
|
||||||
|
repo_id: str,
|
||||||
|
metadata: ModelSourceMetadata,
|
||||||
|
variant: Optional[str] = None,
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
|
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
|
||||||
|
"""
|
||||||
|
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
||||||
|
|
||||||
|
The metadata field will be updated with model metadata from HuggingFace.
|
||||||
|
|
||||||
|
Known variants currently are:
|
||||||
|
1. onnx
|
||||||
|
2. openvino
|
||||||
|
3. fp16
|
||||||
|
4. None (usually returns fp32 model)
|
||||||
|
"""
|
||||||
|
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
||||||
|
sibs = model_info.siblings
|
||||||
|
paths = [x.rfilename for x in sibs]
|
||||||
|
sizes = {x.rfilename: x.size for x in sibs}
|
||||||
|
|
||||||
|
prefix = ""
|
||||||
|
if subfolder:
|
||||||
|
prefix = f"{subfolder}/"
|
||||||
|
paths = [x for x in paths if x.startswith(prefix)]
|
||||||
|
|
||||||
|
if f"{prefix}model_index.json" in paths:
|
||||||
|
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
|
||||||
|
resp = self._requests.get(url)
|
||||||
|
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
||||||
|
submodels = resp.json()
|
||||||
|
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
||||||
|
paths.insert(0, f"{prefix}model_index.json")
|
||||||
|
urls = [
|
||||||
|
(
|
||||||
|
hf_hub_url(repo_id, filename=x.as_posix()),
|
||||||
|
x.parent.relative_to(prefix) or Path("."),
|
||||||
|
Path(x.name),
|
||||||
|
sizes[x.as_posix()],
|
||||||
|
)
|
||||||
|
for x in self._select_variants(paths, variant)
|
||||||
|
]
|
||||||
|
if hasattr(model_info, "cardData"):
|
||||||
|
metadata.license = metadata.license or model_info.cardData.get("license")
|
||||||
|
metadata.tags = metadata.tags or model_info.tags
|
||||||
|
metadata.author = metadata.author or model_info.author
|
||||||
|
return urls
|
||||||
|
|
||||||
|
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
||||||
|
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
||||||
|
result = set()
|
||||||
|
basenames: Dict[Path, Path] = dict()
|
||||||
|
for p in paths:
|
||||||
|
path = Path(p)
|
||||||
|
|
||||||
|
if path.suffix == ".onnx":
|
||||||
|
if variant == "onnx":
|
||||||
|
result.add(path)
|
||||||
|
|
||||||
|
elif path.name.startswith("openvino_model"):
|
||||||
|
if variant == "openvino":
|
||||||
|
result.add(path)
|
||||||
|
|
||||||
|
elif path.suffix in [".json", ".txt"]:
|
||||||
|
result.add(path)
|
||||||
|
|
||||||
|
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
|
||||||
|
parent = path.parent
|
||||||
|
suffixes = path.suffixes
|
||||||
|
if len(suffixes) == 2:
|
||||||
|
file_variant, suffix = suffixes
|
||||||
|
basename = parent / Path(path.stem).stem
|
||||||
|
else:
|
||||||
|
file_variant = None
|
||||||
|
suffix = suffixes[0]
|
||||||
|
basename = parent / path.stem
|
||||||
|
|
||||||
|
if previous := basenames.get(basename):
|
||||||
|
if previous.suffix != ".safetensors" and suffix == ".safetensors":
|
||||||
|
basenames[basename] = path
|
||||||
|
if file_variant == f".{variant}":
|
||||||
|
basenames[basename] = path
|
||||||
|
elif not variant and not file_variant:
|
||||||
|
basenames[basename] = path
|
||||||
|
else:
|
||||||
|
basenames[basename] = path
|
||||||
|
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
for v in basenames.values():
|
||||||
|
result.add(v)
|
||||||
|
|
||||||
|
return result
|
@ -9,7 +9,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import PriorityQueue
|
from queue import PriorityQueue
|
||||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from huggingface_hub import HfApi, hf_hub_url
|
from huggingface_hub import HfApi, hf_hub_url
|
||||||
@ -21,14 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
|
|||||||
from invokeai.backend.util import InvokeAILogger, Logger
|
from invokeai.backend.util import InvokeAILogger, Logger
|
||||||
|
|
||||||
from ..storage import DuplicateModelException
|
from ..storage import DuplicateModelException
|
||||||
from .base import (
|
from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
|
||||||
DownloadEventHandler,
|
|
||||||
DownloadJobBase,
|
|
||||||
DownloadJobStatus,
|
|
||||||
DownloadQueueBase,
|
|
||||||
ModelSourceMetadata,
|
|
||||||
UnknownJobIDException,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Maximum number of bytes to download during each call to requests.iter_content()
|
# Maximum number of bytes to download during each call to requests.iter_content()
|
||||||
DOWNLOAD_CHUNK_SIZE = 100000
|
DOWNLOAD_CHUNK_SIZE = 100000
|
||||||
@ -36,17 +29,8 @@ DOWNLOAD_CHUNK_SIZE = 100000
|
|||||||
# marker that the queue is done and that thread should exit
|
# marker that the queue is done and that thread should exit
|
||||||
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
||||||
|
|
||||||
# endpoint for civitai get-model API
|
# regular expression for picking up a URL
|
||||||
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
|
|
||||||
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
|
|
||||||
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
|
||||||
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
|
|
||||||
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
|
||||||
|
|
||||||
# Regular expressions to describe repo_ids and http urls
|
|
||||||
HTTP_RE = r"^https?://"
|
HTTP_RE = r"^https?://"
|
||||||
REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
|
|
||||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadJobPath(DownloadJobBase):
|
class DownloadJobPath(DownloadJobBase):
|
||||||
@ -69,32 +53,6 @@ class DownloadJobURL(DownloadJobRemoteSource):
|
|||||||
source: AnyHttpUrl = Field(description="URL to download")
|
source: AnyHttpUrl = Field(description="URL to download")
|
||||||
|
|
||||||
|
|
||||||
class DownloadJobRepoID(DownloadJobRemoteSource):
|
|
||||||
"""Download repo ids."""
|
|
||||||
|
|
||||||
source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
|
|
||||||
subfolder: Optional[str] = Field(
|
|
||||||
description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
|
|
||||||
)
|
|
||||||
variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
|
|
||||||
subqueue: Optional["DownloadQueueBase"] = Field(
|
|
||||||
description="a subqueue used for downloading the individual files in the repo_id", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
@validator("source")
|
|
||||||
@classmethod
|
|
||||||
def proper_repo_id(cls, v: str) -> str: # noqa D102
|
|
||||||
if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
|
|
||||||
raise ValueError(f"{v}: invalid repo_id format")
|
|
||||||
return v
|
|
||||||
|
|
||||||
def cleanup(self, preserve_partial_downloads: bool = False):
|
|
||||||
"""Perform action when job is completed."""
|
|
||||||
if self.subqueue:
|
|
||||||
self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
|
|
||||||
self.subqueue.release()
|
|
||||||
|
|
||||||
|
|
||||||
class DownloadQueue(DownloadQueueBase):
|
class DownloadQueue(DownloadQueueBase):
|
||||||
"""Class for queued download of models."""
|
"""Class for queued download of models."""
|
||||||
|
|
||||||
@ -145,7 +103,7 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
filename: Optional[Path] = None,
|
filename: Optional[Path] = None,
|
||||||
variant: Optional[str] = None,
|
variant: Optional[str] = None,
|
||||||
access_token: Optional[str] = None,
|
access_token: Optional[str] = None,
|
||||||
event_handlers: Optional[List[DownloadEventHandler]] = None,
|
event_handlers: List[DownloadEventHandler] = [],
|
||||||
) -> DownloadJobBase:
|
) -> DownloadJobBase:
|
||||||
"""Create a download job and return its ID."""
|
"""Create a download job and return its ID."""
|
||||||
kwargs: Dict[str, Optional[str]] = dict()
|
kwargs: Dict[str, Optional[str]] = dict()
|
||||||
@ -153,12 +111,6 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
cls = DownloadJobBase
|
cls = DownloadJobBase
|
||||||
if Path(source).exists():
|
if Path(source).exists():
|
||||||
cls = DownloadJobPath
|
cls = DownloadJobPath
|
||||||
elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
|
|
||||||
cls = DownloadJobRepoID
|
|
||||||
kwargs.update(
|
|
||||||
variant=variant,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
elif re.match(HTTP_RE, str(source)):
|
elif re.match(HTTP_RE, str(source)):
|
||||||
cls = DownloadJobURL
|
cls = DownloadJobURL
|
||||||
kwargs.update(access_token=access_token)
|
kwargs.update(access_token=access_token)
|
||||||
@ -168,7 +120,7 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
job = cls(
|
job = cls(
|
||||||
source=source,
|
source=source,
|
||||||
destination=Path(destdir) / (filename or "."),
|
destination=Path(destdir) / (filename or "."),
|
||||||
event_handlers=(event_handlers or self._event_handlers),
|
event_handlers=event_handlers,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
@ -349,89 +301,32 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
if job.status == DownloadJobStatus.ENQUEUED: # Don't do anything for non-enqueued jobs (shouldn't happen)
|
||||||
if not self._quiet:
|
if not self._quiet:
|
||||||
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
self._logger.info(f"{job.source}: Downloading to {job.destination}")
|
||||||
if isinstance(job, DownloadJobURL):
|
do_download = self.select_downloader(job)
|
||||||
self._download_with_resume(job)
|
do_download(job)
|
||||||
elif isinstance(job, DownloadJobRepoID):
|
|
||||||
self._download_repoid(job)
|
|
||||||
elif isinstance(job, DownloadJobPath):
|
|
||||||
self._download_path(job)
|
|
||||||
else:
|
|
||||||
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
|
|
||||||
|
|
||||||
if job.status == DownloadJobStatus.CANCELLED:
|
if job.status == DownloadJobStatus.CANCELLED:
|
||||||
self._cleanup_cancelled_job(job)
|
self._cleanup_cancelled_job(job)
|
||||||
|
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
|
|
||||||
def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl:
|
def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
|
||||||
"""
|
"""Based on the job type select the download method."""
|
||||||
Fetch metadata from certain well-known URLs.
|
if isinstance(job, DownloadJobURL):
|
||||||
|
return self._download_with_resume
|
||||||
|
elif isinstance(job, DownloadJobPath):
|
||||||
|
return self._download_path
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
|
||||||
|
|
||||||
The metadata will be stashed in job.metadata, if found
|
def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
|
||||||
Return the download URL.
|
return job.source
|
||||||
"""
|
|
||||||
metadata = job.metadata
|
|
||||||
url = job.source
|
|
||||||
metadata_url = url
|
|
||||||
model = None
|
|
||||||
|
|
||||||
# a Civitai download URL
|
|
||||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
|
|
||||||
version = match.group(1)
|
|
||||||
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
|
||||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
|
||||||
metadata.description = metadata.description or (
|
|
||||||
f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
|
|
||||||
)
|
|
||||||
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
|
|
||||||
|
|
||||||
# a Civitai model page with the version
|
|
||||||
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
|
|
||||||
model = match.group(1)
|
|
||||||
version = int(match.group(2))
|
|
||||||
# and without
|
|
||||||
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
|
|
||||||
model = match.group(1)
|
|
||||||
version = None
|
|
||||||
|
|
||||||
if not model:
|
|
||||||
return parse_obj_as(AnyHttpUrl, url)
|
|
||||||
|
|
||||||
if model:
|
|
||||||
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
|
|
||||||
|
|
||||||
metadata.author = metadata.author or resp["creator"]["username"]
|
|
||||||
metadata.tags = metadata.tags or resp["tags"]
|
|
||||||
metadata.license = (
|
|
||||||
metadata.license
|
|
||||||
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if version:
|
|
||||||
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
|
|
||||||
version_data = versions[0]
|
|
||||||
else:
|
|
||||||
version_data = resp["modelVersions"][0] # first one
|
|
||||||
|
|
||||||
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
|
|
||||||
metadata.description = metadata.description or (
|
|
||||||
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
|
|
||||||
if version_data.get("trainedWords")
|
|
||||||
else version_data.get("description")
|
|
||||||
)
|
|
||||||
|
|
||||||
download_url = version_data["downloadUrl"]
|
|
||||||
|
|
||||||
# return the download url
|
|
||||||
return download_url
|
|
||||||
|
|
||||||
def _download_with_resume(self, job: DownloadJobBase):
|
def _download_with_resume(self, job: DownloadJobBase):
|
||||||
"""Do the actual download."""
|
"""Do the actual download."""
|
||||||
dest = None
|
dest = None
|
||||||
try:
|
try:
|
||||||
assert isinstance(job, DownloadJobRemoteSource)
|
assert isinstance(job, DownloadJobRemoteSource)
|
||||||
url = self._get_metadata_and_url(job)
|
url = self.get_url_for_job(job)
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||||
open_mode = "wb"
|
open_mode = "wb"
|
||||||
exist_size = 0
|
exist_size = 0
|
||||||
@ -497,7 +392,8 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
if job.bytes - last_report_bytes >= report_delta:
|
if job.bytes - last_report_bytes >= report_delta:
|
||||||
last_report_bytes = job.bytes
|
last_report_bytes = job.bytes
|
||||||
self._update_job_status(job)
|
self._update_job_status(job)
|
||||||
|
if job.status != DownloadJobStatus.RUNNING: # cancelled, paused or errored
|
||||||
|
return
|
||||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
||||||
except KeyboardInterrupt as excp:
|
except KeyboardInterrupt as excp:
|
||||||
raise excp
|
raise excp
|
||||||
@ -541,166 +437,6 @@ class DownloadQueue(DownloadQueueBase):
|
|||||||
job.error = excp
|
job.error = excp
|
||||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
self._update_job_status(job, DownloadJobStatus.ERROR)
|
||||||
|
|
||||||
def _download_repoid(self, job: DownloadJobRepoID):
|
|
||||||
"""Download a job that holds a huggingface repoid."""
|
|
||||||
|
|
||||||
def subdownload_event(subjob: DownloadJobBase):
|
|
||||||
assert isinstance(subjob, DownloadJobRemoteSource)
|
|
||||||
if subjob.status == DownloadJobStatus.RUNNING:
|
|
||||||
bytes_downloaded[subjob.id] = subjob.bytes
|
|
||||||
job.bytes = sum(bytes_downloaded.values())
|
|
||||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
|
||||||
return
|
|
||||||
|
|
||||||
if subjob.status == DownloadJobStatus.ERROR:
|
|
||||||
job.error = subjob.error
|
|
||||||
job.cleanup()
|
|
||||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
|
||||||
return
|
|
||||||
|
|
||||||
if subjob.status == DownloadJobStatus.COMPLETED:
|
|
||||||
bytes_downloaded[subjob.id] = subjob.bytes
|
|
||||||
job.bytes = sum(bytes_downloaded.values())
|
|
||||||
self._update_job_status(job, DownloadJobStatus.RUNNING)
|
|
||||||
return
|
|
||||||
|
|
||||||
subqueue = self.__class__(
|
|
||||||
event_handlers=[subdownload_event],
|
|
||||||
requests_session=self._requests,
|
|
||||||
quiet=True,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
assert isinstance(job, DownloadJobRepoID)
|
|
||||||
repo_id = job.source
|
|
||||||
variant = job.variant
|
|
||||||
if not job.metadata:
|
|
||||||
job.metadata = ModelSourceMetadata()
|
|
||||||
urls_to_download = self._get_repo_info(
|
|
||||||
repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
|
|
||||||
)
|
|
||||||
if job.destination.name != Path(repo_id).name:
|
|
||||||
job.destination = job.destination / Path(repo_id).name
|
|
||||||
bytes_downloaded: Dict[int, int] = dict()
|
|
||||||
job.total_bytes = 0
|
|
||||||
|
|
||||||
for url, subdir, file, size in urls_to_download:
|
|
||||||
job.total_bytes += size
|
|
||||||
subqueue.create_download_job(
|
|
||||||
source=url,
|
|
||||||
destdir=job.destination / subdir,
|
|
||||||
filename=file,
|
|
||||||
variant=variant,
|
|
||||||
access_token=job.access_token,
|
|
||||||
)
|
|
||||||
except KeyboardInterrupt as excp:
|
|
||||||
raise excp
|
|
||||||
except Exception as excp:
|
|
||||||
job.error = excp
|
|
||||||
self._update_job_status(job, DownloadJobStatus.ERROR)
|
|
||||||
self._logger.error(job.error)
|
|
||||||
finally:
|
|
||||||
job.subqueue = subqueue
|
|
||||||
job.subqueue.join()
|
|
||||||
if job.status == DownloadJobStatus.RUNNING:
|
|
||||||
self._update_job_status(job, DownloadJobStatus.COMPLETED)
|
|
||||||
job.subqueue.release() # get rid of the subqueue
|
|
||||||
|
|
||||||
def _get_repo_info(
|
|
||||||
self,
|
|
||||||
repo_id: str,
|
|
||||||
metadata: ModelSourceMetadata,
|
|
||||||
variant: Optional[str] = None,
|
|
||||||
subfolder: Optional[str] = None,
|
|
||||||
) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
|
|
||||||
"""
|
|
||||||
Given a repo_id and an optional variant, return list of URLs to download to get the model.
|
|
||||||
The metadata field will be updated with model metadata from HuggingFace.
|
|
||||||
|
|
||||||
Known variants currently are:
|
|
||||||
1. onnx
|
|
||||||
2. openvino
|
|
||||||
3. fp16
|
|
||||||
4. None (usually returns fp32 model)
|
|
||||||
"""
|
|
||||||
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
|
|
||||||
sibs = model_info.siblings
|
|
||||||
paths = [x.rfilename for x in sibs]
|
|
||||||
sizes = {x.rfilename: x.size for x in sibs}
|
|
||||||
|
|
||||||
prefix = ""
|
|
||||||
if subfolder:
|
|
||||||
prefix = f"{subfolder}/"
|
|
||||||
paths = [x for x in paths if x.startswith(prefix)]
|
|
||||||
|
|
||||||
if f"{prefix}model_index.json" in paths:
|
|
||||||
url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
|
|
||||||
resp = self._requests.get(url)
|
|
||||||
resp.raise_for_status() # will raise an HTTPError on non-200 status
|
|
||||||
submodels = resp.json()
|
|
||||||
paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
|
|
||||||
paths.insert(0, f"{prefix}model_index.json")
|
|
||||||
urls = [
|
|
||||||
(
|
|
||||||
hf_hub_url(repo_id, filename=x.as_posix()),
|
|
||||||
x.parent.relative_to(prefix) or Path("."),
|
|
||||||
Path(x.name),
|
|
||||||
sizes[x.as_posix()],
|
|
||||||
)
|
|
||||||
for x in self._select_variants(paths, variant)
|
|
||||||
]
|
|
||||||
if hasattr(model_info, "cardData"):
|
|
||||||
metadata.license = metadata.license or model_info.cardData.get("license")
|
|
||||||
metadata.tags = metadata.tags or model_info.tags
|
|
||||||
metadata.author = metadata.author or model_info.author
|
|
||||||
return urls
|
|
||||||
|
|
||||||
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
|
|
||||||
"""Select the proper variant files from a list of HuggingFace repo_id paths."""
|
|
||||||
result = set()
|
|
||||||
basenames: Dict[Path, Path] = dict()
|
|
||||||
for p in paths:
|
|
||||||
path = Path(p)
|
|
||||||
|
|
||||||
if path.suffix == ".onnx":
|
|
||||||
if variant == "onnx":
|
|
||||||
result.add(path)
|
|
||||||
|
|
||||||
elif path.name.startswith("openvino_model"):
|
|
||||||
if variant == "openvino":
|
|
||||||
result.add(path)
|
|
||||||
|
|
||||||
elif path.suffix in [".json", ".txt"]:
|
|
||||||
result.add(path)
|
|
||||||
|
|
||||||
elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
|
|
||||||
parent = path.parent
|
|
||||||
suffixes = path.suffixes
|
|
||||||
if len(suffixes) == 2:
|
|
||||||
file_variant, suffix = suffixes
|
|
||||||
basename = parent / Path(path.stem).stem
|
|
||||||
else:
|
|
||||||
file_variant = None
|
|
||||||
suffix = suffixes[0]
|
|
||||||
basename = parent / path.stem
|
|
||||||
|
|
||||||
if previous := basenames.get(basename):
|
|
||||||
if previous.suffix != ".safetensors" and suffix == ".safetensors":
|
|
||||||
basenames[basename] = path
|
|
||||||
if file_variant == f".{variant}":
|
|
||||||
basenames[basename] = path
|
|
||||||
elif not variant and not file_variant:
|
|
||||||
basenames[basename] = path
|
|
||||||
else:
|
|
||||||
basenames[basename] = path
|
|
||||||
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
for v in basenames.values():
|
|
||||||
result.add(v)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def _download_path(self, job: DownloadJobBase):
|
def _download_path(self, job: DownloadJobBase):
|
||||||
"""Call when the source is a Path or pathlike object."""
|
"""Call when the source is a Path or pathlike object."""
|
||||||
source = Path(job.source).resolve()
|
source = Path(job.source).resolve()
|
||||||
|
@ -72,14 +72,20 @@ from .config import (
|
|||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SubModelType,
|
SubModelType,
|
||||||
)
|
)
|
||||||
from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
|
from .download import (
|
||||||
from .download.queue import (
|
DownloadEventHandler,
|
||||||
|
DownloadJobBase,
|
||||||
|
DownloadJobPath,
|
||||||
|
DownloadJobURL,
|
||||||
|
DownloadQueueBase,
|
||||||
|
ModelDownloadQueue,
|
||||||
|
ModelSourceMetadata,
|
||||||
|
)
|
||||||
|
from .download.model_queue import (
|
||||||
HTTP_RE,
|
HTTP_RE,
|
||||||
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
|
||||||
DownloadJobRemoteSource,
|
|
||||||
DownloadJobPath,
|
|
||||||
DownloadJobRepoID,
|
DownloadJobRepoID,
|
||||||
DownloadJobURL,
|
DownloadJobWithMetadata,
|
||||||
)
|
)
|
||||||
from .hash import FastModelHash
|
from .hash import FastModelHash
|
||||||
from .models import InvalidModelException
|
from .models import InvalidModelException
|
||||||
@ -88,7 +94,7 @@ from .search import ModelSearch
|
|||||||
from .storage import DuplicateModelException, ModelConfigStore
|
from .storage import DuplicateModelException, ModelConfigStore
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallJob(DownloadJobRemoteSource):
|
class ModelInstallJob(DownloadJobBase):
|
||||||
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
"""This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
|
||||||
|
|
||||||
model_key: Optional[str] = Field(
|
model_key: Optional[str] = Field(
|
||||||
@ -100,7 +106,7 @@ class ModelInstallJob(DownloadJobRemoteSource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelInstallURLJob(DownloadJobURL, ModelInstallJob):
|
class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
|
||||||
"""Job for installing URLs."""
|
"""Job for installing URLs."""
|
||||||
|
|
||||||
|
|
||||||
@ -398,7 +404,7 @@ class ModelInstall(ModelInstallBase):
|
|||||||
self._app_config = config or InvokeAIAppConfig.get_config()
|
self._app_config = config or InvokeAIAppConfig.get_config()
|
||||||
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
|
self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
|
||||||
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
|
self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
|
||||||
self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
|
self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
|
||||||
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
|
||||||
self._installed = set()
|
self._installed = set()
|
||||||
self._tmpdir = None
|
self._tmpdir = None
|
||||||
|
@ -8,11 +8,11 @@ import requests
|
|||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from requests_testadapter import TestAdapter
|
from requests_testadapter import TestAdapter
|
||||||
|
|
||||||
import invokeai.backend.model_manager.download.queue as download_queue
|
import invokeai.backend.model_manager.download.model_queue as download_queue
|
||||||
from invokeai.backend.model_manager.download import (
|
from invokeai.backend.model_manager.download import (
|
||||||
DownloadJobBase,
|
DownloadJobBase,
|
||||||
DownloadJobStatus,
|
DownloadJobStatus,
|
||||||
DownloadQueue,
|
ModelDownloadQueue,
|
||||||
UnknownJobIDException,
|
UnknownJobIDException,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -147,7 +147,7 @@ def test_basic_queue_download():
|
|||||||
def event_handler(job: DownloadJobBase):
|
def event_handler(job: DownloadJobBase):
|
||||||
events.append(job.status)
|
events.append(job.status)
|
||||||
|
|
||||||
queue = DownloadQueue(
|
queue = ModelDownloadQueue(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
event_handlers=[event_handler],
|
event_handlers=[event_handler],
|
||||||
)
|
)
|
||||||
@ -167,7 +167,7 @@ def test_basic_queue_download():
|
|||||||
|
|
||||||
|
|
||||||
def test_queue_priority():
|
def test_queue_priority():
|
||||||
queue = DownloadQueue(
|
queue = ModelDownloadQueue(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -200,7 +200,7 @@ def test_repo_id_download():
|
|||||||
if not INTERNET_AVAILABLE:
|
if not INTERNET_AVAILABLE:
|
||||||
return
|
return
|
||||||
repo_id = "stabilityai/stable-diffusion-2-1"
|
repo_id = "stabilityai/stable-diffusion-2-1"
|
||||||
queue = DownloadQueue(
|
queue = ModelDownloadQueue(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -224,7 +224,7 @@ def test_repo_id_download():
|
|||||||
|
|
||||||
|
|
||||||
def test_bad_urls():
|
def test_bad_urls():
|
||||||
queue = DownloadQueue(
|
queue = ModelDownloadQueue(
|
||||||
requests_session=session,
|
requests_session=session,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -270,9 +270,11 @@ def test_bad_urls():
|
|||||||
|
|
||||||
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
|
def test_pause_cancel_url(): # this one is tricky because of potential race conditions
|
||||||
def event_handler(job: DownloadJobBase):
|
def event_handler(job: DownloadJobBase):
|
||||||
time.sleep(0.5) # slow down the thread by blocking it just a bit at every step
|
if job.id == 0:
|
||||||
|
print(job.status, job.bytes)
|
||||||
|
time.sleep(0.5) # slow down the thread so that we can recover the paused state
|
||||||
|
|
||||||
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
|
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
|
||||||
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
|
||||||
@ -322,7 +324,7 @@ def test_pause_cancel_url(): # this one is tricky because of potential race con
|
|||||||
return
|
return
|
||||||
|
|
||||||
repo_id = "stabilityai/stable-diffusion-2-1"
|
repo_id = "stabilityai/stable-diffusion-2-1"
|
||||||
queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
|
queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
|
with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
|
||||||
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
|
job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
|
||||||
|
@ -55,7 +55,8 @@ def test_install(datadir: Path):
|
|||||||
mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)
|
mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)
|
||||||
|
|
||||||
source = datadir / TEST_MODEL
|
source = datadir / TEST_MODEL
|
||||||
mm_install.install_model(source=source)
|
job = mm_install.install_model(source=source)
|
||||||
|
print(f"DEBUG: job={type(job)}")
|
||||||
id_map = mm_install.wait_for_installs()
|
id_map = mm_install.wait_for_installs()
|
||||||
print(id_map)
|
print(id_map)
|
||||||
assert source in id_map, "model did not install; id_map empty"
|
assert source in id_map, "model did not install; id_map empty"
|
||||||
|
Reference in New Issue
Block a user