diff --git a/invokeai/backend/model_manager/download/__init__.py b/invokeai/backend/model_manager/download/__init__.py
index 370f9127b4..99ae479b48 100644
--- a/invokeai/backend/model_manager/download/__init__.py
+++ b/invokeai/backend/model_manager/download/__init__.py
@@ -5,7 +5,7 @@ from .base import (  # noqa F401
     DownloadJobBase,
     DownloadJobStatus,
     DownloadQueueBase,
-    ModelSourceMetadata,
     UnknownJobIDException,
 )
-from .queue import DownloadQueue  # noqa F401
+from .model_queue import ModelDownloadQueue, ModelSourceMetadata  # noqa F401
+from .queue import DownloadJobPath, DownloadJobURL, DownloadQueue  # noqa F401
diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py
index 5240c68f6c..56998a806a 100644
--- a/invokeai/backend/model_manager/download/base.py
+++ b/invokeai/backend/model_manager/download/base.py
@@ -30,17 +30,6 @@ class UnknownJobIDException(Exception):
     """Raised when an invalid Job is referenced."""
 
 
-class ModelSourceMetadata(BaseModel):
-    """Information collected on a downloadable model from its source site."""
-
-    name: Optional[str] = Field(description="Human-readable name of this model")
-    author: Optional[str] = Field(description="Author/creator of the model")
-    description: Optional[str] = Field(description="Description of the model")
-    license: Optional[str] = Field(description="Model license terms")
-    thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
-    tags: Optional[List[str]] = Field(description="List of descriptive tags")
-
-
 DownloadEventHandler = Callable[["DownloadJobBase"], None]
 
 
@@ -67,9 +56,6 @@ class DownloadJobBase(BaseModel):
         description="if true, then preserve partial downloads when cancelled or errored", default=False
     )
     error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
-    metadata: ModelSourceMetadata = Field(
-        description="Metadata describing download contents", default_factory=ModelSourceMetadata
-    )
 
     def add_event_handler(self, handler: DownloadEventHandler):
         """Add an event handler to the end of the handlers list."""
@@ -134,7 +120,7 @@
         filename: Optional[Path] = None,
         variant: Optional[str] = None,
         access_token: Optional[str] = None,
-        event_handlers: Optional[List[DownloadEventHandler]] = None,
+        event_handlers: List[DownloadEventHandler] = [],
     ) -> DownloadJobBase:
         """
         Create and submit a download job.
@@ -274,3 +260,17 @@
         no longer recognize the job.
         """
         pass
+
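+    # These hooks implement a template-method pattern: the queue's worker loop
+    # calls select_downloader() to pick the per-job download method and
+    # get_url_for_job() to resolve the concrete URL, so subclasses can support
+    # new source types without touching the queue machinery itself.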
+ """ + pass diff --git a/invokeai/backend/model_manager/download/model_queue.py b/invokeai/backend/model_manager/download/model_queue.py new file mode 100644 index 0000000000..a1dc3c6ddb --- /dev/null +++ b/invokeai/backend/model_manager/download/model_queue.py @@ -0,0 +1,349 @@ +import re +from pathlib import Path +from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union + +from huggingface_hub import HfApi, hf_hub_url +from pydantic import BaseModel, Field, parse_obj_as, validator +from pydantic.networks import AnyHttpUrl + +from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase +from .queue import HTTP_RE, DownloadJobRemoteSource, DownloadQueue + +# regular expressions used to dispatch appropriate downloaders and metadata retrievers +# endpoint for civitai get-model API +CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)" +CIVITAI_MODEL_PAGE = "https://civitai.com/models/" +CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)" +CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/" +CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/" + +# Regular expressions to describe repo_ids and http urls +REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$" + + +class ModelSourceMetadata(BaseModel): + """Information collected on a downloadable model from its source site.""" + + name: Optional[str] = Field(description="Human-readable name of this model") + author: Optional[str] = Field(description="Author/creator of the model") + description: Optional[str] = Field(description="Description of the model") + license: Optional[str] = Field(description="Model license terms") + thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model") + tags: Optional[List[str]] = Field(description="List of descriptive tags") + + +class DownloadJobWithMetadata(DownloadJobRemoteSource): + """A remote download that has metadata associated with it.""" + + metadata: ModelSourceMetadata = Field( + description="Metadata describing the model, derived from source", default_factory=ModelSourceMetadata + ) + + +class DownloadJobRepoID(DownloadJobWithMetadata): + """Download repo ids.""" + + source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)") + subfolder: Optional[str] = Field( + description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None + ) + variant: Optional[str] = Field(description="Variant, such as 'fp16', to download") + subqueue: Optional[DownloadQueueBase] = Field( + description="a subqueue used for downloading the individual files in the repo_id", default=None + ) + + @validator("source") + @classmethod + def proper_repo_id(cls, v: str) -> str: # noqa D102 + if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v): + raise ValueError(f"{v}: invalid repo_id format") + return v + + def cleanup(self, preserve_partial_downloads: bool = False): + """Perform action when job is completed.""" + if self.subqueue: + self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads) + self.subqueue.release() + + +class ModelDownloadQueue(DownloadQueue): + """Subclass of DownloadQueue, able to retrieve metadata from HuggingFace and Civitai.""" + + def create_download_job( + self, + source: Union[str, Path, AnyHttpUrl], + destdir: Path, + start: bool = True, + priority: int = 10, + filename: Optional[Path] = None, + variant: Optional[str] 
+        if re.match(HTTP_RE, str(source)):
+            cls = DownloadJobWithMetadata
+            kwargs.update(access_token=access_token)
+        elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
+            cls = DownloadJobRepoID
+            kwargs.update(
+                variant=variant,
+                access_token=access_token,
+            )
+        if cls:
+            job = cls(
+                source=source,
+                destination=Path(destdir) / (filename or "."),
+                event_handlers=event_handlers,
+                priority=priority,
+                **kwargs,
+            )
+            return self.submit_download_job(job, start)
+        else:
+            return super().create_download_job(
+                source=source,
+                destdir=destdir,
+                start=start,
+                priority=priority,
+                filename=filename,
+                variant=variant,
+                access_token=access_token,
+                event_handlers=event_handlers,
+            )
+
+    def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
+        """Based on the job type select the download method."""
+        if isinstance(job, DownloadJobRepoID):
+            return self._download_repoid
+        elif isinstance(job, DownloadJobWithMetadata):
+            return self._download_with_resume
+        else:
+            return super().select_downloader(job)
+
+    def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
+        """
+        Fetch metadata from certain well-known URLs.
+
+        The metadata will be stashed in job.metadata, if found
+        Return the download URL.
+        """
+        assert isinstance(job, DownloadJobWithMetadata)
+        metadata = job.metadata
+        url = job.source
+        metadata_url = url
+        model = None
+
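+        # A Civitai download URL identifies a model *version*; resolve it to
+        # the corresponding model page first, so that both URL forms are
+        # handled by the same lookup below, collecting metadata along the way.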
+        # a Civitai download URL
+        if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
+            version = match.group(1)
+            resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
+            metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
+            metadata.description = metadata.description or (
+                f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
+            )
+            metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
+
+        # a Civitai model page with the version
+        if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
+            model = match.group(1)
+            version = int(match.group(2))
+        # and without
+        elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
+            model = match.group(1)
+            version = None
+
+        if not model:
+            return parse_obj_as(AnyHttpUrl, url)
+
+        if model:
+            resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
+
+            metadata.author = metadata.author or resp["creator"]["username"]
+            metadata.tags = metadata.tags or resp["tags"]
+            metadata.license = (
+                metadata.license
+                or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
+            )
+
+            if version:
+                versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
+                version_data = versions[0]
+            else:
+                version_data = resp["modelVersions"][0]  # first one
+
+            metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
+            metadata.description = metadata.description or (
+                f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
+                if version_data.get("trainedWords")
+                else version_data.get("description")
+            )
+
+            download_url = version_data["downloadUrl"]
+
+        # return the download url
+        return download_url
+
+    def _download_repoid(self, job: DownloadJobBase) -> None:
+        """Download a job that holds a huggingface repoid."""
+
+        def subdownload_event(subjob: DownloadJobBase):
+            assert isinstance(subjob, DownloadJobRemoteSource)
+            assert isinstance(job, DownloadJobRemoteSource)
+            if subjob.status == DownloadJobStatus.RUNNING:
+                bytes_downloaded[subjob.id] = subjob.bytes
+                job.bytes = sum(bytes_downloaded.values())
+                self._update_job_status(job, DownloadJobStatus.RUNNING)
+                return
+
+            if subjob.status == DownloadJobStatus.ERROR:
+                job.error = subjob.error
+                job.cleanup()
+                self._update_job_status(job, DownloadJobStatus.ERROR)
+                return
+
+            if subjob.status == DownloadJobStatus.COMPLETED:
+                bytes_downloaded[subjob.id] = subjob.bytes
+                job.bytes = sum(bytes_downloaded.values())
+                self._update_job_status(job, DownloadJobStatus.RUNNING)
+                return
+
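+        # Every file in the repo becomes its own job on a private subqueue;
+        # subdownload_event() above rolls the per-file byte counts up into
+        # this parent job so callers see a single aggregate progress report.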
+        subqueue = self.__class__(
+            event_handlers=[subdownload_event],
+            requests_session=self._requests,
+            quiet=True,
+        )
+        assert isinstance(job, DownloadJobRepoID)
+        try:
+            repo_id = job.source
+            variant = job.variant
+            if not job.metadata:
+                job.metadata = ModelSourceMetadata()
+            urls_to_download = self._get_repo_info(
+                repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
+            )
+            if job.destination.name != Path(repo_id).name:
+                job.destination = job.destination / Path(repo_id).name
+            bytes_downloaded: Dict[int, int] = dict()
+            job.total_bytes = 0
+
+            for url, subdir, file, size in urls_to_download:
+                job.total_bytes += size
+                subqueue.create_download_job(
+                    source=url,
+                    destdir=job.destination / subdir,
+                    filename=file,
+                    variant=variant,
+                    access_token=job.access_token,
+                )
+        except KeyboardInterrupt as excp:
+            raise excp
+        except Exception as excp:
+            job.error = excp
+            self._update_job_status(job, DownloadJobStatus.ERROR)
+            self._logger.error(job.error)
+        finally:
+            job.subqueue = subqueue
+            job.subqueue.join()
+            if job.status == DownloadJobStatus.RUNNING:
+                self._update_job_status(job, DownloadJobStatus.COMPLETED)
+            job.subqueue.release()  # get rid of the subqueue
+
+    def _get_repo_info(
+        self,
+        repo_id: str,
+        metadata: ModelSourceMetadata,
+        variant: Optional[str] = None,
+        subfolder: Optional[str] = None,
+    ) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
+        """
+        Given a repo_id and an optional variant, return list of URLs to download to get the model.
+
+        The metadata field will be updated with model metadata from HuggingFace.
+
+        Known variants currently are:
+        1. onnx
+        2. openvino
+        3. fp16
+        4. None (usually returns fp32 model)
+        """
+        model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
+        sibs = model_info.siblings
+        paths = [x.rfilename for x in sibs]
+        sizes = {x.rfilename: x.size for x in sibs}
+
+        prefix = ""
+        if subfolder:
+            prefix = f"{subfolder}/"
+            paths = [x for x in paths if x.startswith(prefix)]
+
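+        # A model_index.json marks a diffusers pipeline: fetch it to learn
+        # which subfolders hold real submodels, and prune every path that
+        # does not belong to one of them.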
+        if f"{prefix}model_index.json" in paths:
+            url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
+            resp = self._requests.get(url)
+            resp.raise_for_status()  # will raise an HTTPError on non-200 status
+            submodels = resp.json()
+            paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
+            paths.insert(0, f"{prefix}model_index.json")
+        urls = [
+            (
+                hf_hub_url(repo_id, filename=x.as_posix()),
+                x.parent.relative_to(prefix) or Path("."),
+                Path(x.name),
+                sizes[x.as_posix()],
+            )
+            for x in self._select_variants(paths, variant)
+        ]
+        if hasattr(model_info, "cardData"):
+            metadata.license = metadata.license or model_info.cardData.get("license")
+        metadata.tags = metadata.tags or model_info.tags
+        metadata.author = metadata.author or model_info.author
+        return urls
+
+    def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
+        """Select the proper variant files from a list of HuggingFace repo_id paths."""
+        result = set()
+        basenames: Dict[Path, Path] = dict()
+        for p in paths:
+            path = Path(p)
+
+            if path.suffix == ".onnx":
+                if variant == "onnx":
+                    result.add(path)
+
+            elif path.name.startswith("openvino_model"):
+                if variant == "openvino":
+                    result.add(path)
+
+            elif path.suffix in [".json", ".txt"]:
+                result.add(path)
+
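+            # For weight files, keep one file per basename: a file whose
+            # ".fp16"-style infix matches the requested variant wins, and
+            # .safetensors is preferred over .bin when both are present.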
+            elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
+                parent = path.parent
+                suffixes = path.suffixes
+                if len(suffixes) == 2:
+                    file_variant, suffix = suffixes
+                    basename = parent / Path(path.stem).stem
+                else:
+                    file_variant = None
+                    suffix = suffixes[0]
+                    basename = parent / path.stem
+
+                if previous := basenames.get(basename):
+                    if previous.suffix != ".safetensors" and suffix == ".safetensors":
+                        basenames[basename] = path
+                    if file_variant == f".{variant}":
+                        basenames[basename] = path
+                    elif not variant and not file_variant:
+                        basenames[basename] = path
+                else:
+                    basenames[basename] = path
+
+            else:
+                continue
+
+        for v in basenames.values():
+            result.add(v)
+
+        return result
diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py
index 0c76a1802b..f25424a44f 100644
--- a/invokeai/backend/model_manager/download/queue.py
+++ b/invokeai/backend/model_manager/download/queue.py
@@ -9,7 +9,7 @@ import time
 import traceback
 from pathlib import Path
 from queue import PriorityQueue
-from typing import Dict, List, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Optional, Set, Tuple, Union
 
 import requests
 from huggingface_hub import HfApi, hf_hub_url
@@ -21,14 +21,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import InvokeAILogger, Logger
 
 from ..storage import DuplicateModelException
-from .base import (
-    DownloadEventHandler,
-    DownloadJobBase,
-    DownloadJobStatus,
-    DownloadQueueBase,
-    ModelSourceMetadata,
-    UnknownJobIDException,
-)
+from .base import DownloadEventHandler, DownloadJobBase, DownloadJobStatus, DownloadQueueBase, UnknownJobIDException
 
 # Maximum number of bytes to download during each call to requests.iter_content()
 DOWNLOAD_CHUNK_SIZE = 100000
@@ -36,17 +29,8 @@ DOWNLOAD_CHUNK_SIZE = 100000
 # marker that the queue is done and that thread should exit
 STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
 
-# endpoint for civitai get-model API
-CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
-CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
-CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
-CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
-CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
-
-# Regular expressions to describe repo_ids and http urls
+# regular expression for picking up a URL
 HTTP_RE = r"^https?://"
-REPO_ID_RE = r"^[\w-]+/[.\w-]+$"
-REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE = r"^([.\w-]+/[.\w-]+)(?::([.\w-]+))?$"
 
 
 class DownloadJobPath(DownloadJobBase):
@@ -69,32 +53,6 @@
     source: AnyHttpUrl = Field(description="URL to download")
 
 
-class DownloadJobRepoID(DownloadJobRemoteSource):
-    """Download repo ids."""
-
-    source: str = Field(description="A repo_id (foo/bar), or a repo_id with a subfolder (foo/far:v2)")
-    subfolder: Optional[str] = Field(
-        description="Provide when the desired model is in a subfolder of the repo_id's distro", default=None
-    )
-    variant: Optional[str] = Field(description="Variant, such as 'fp16', to download")
-    subqueue: Optional["DownloadQueueBase"] = Field(
-        description="a subqueue used for downloading the individual files in the repo_id", default=None
-    )
-
-    @validator("source")
-    @classmethod
-    def proper_repo_id(cls, v: str) -> str:  # noqa D102
-        if not re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, v):
-            raise ValueError(f"{v}: invalid repo_id format")
-        return v
-
-    def cleanup(self, preserve_partial_downloads: bool = False):
-        """Perform action when job is completed."""
-        if self.subqueue:
-            self.subqueue.cancel_all_jobs(preserve_partial=preserve_partial_downloads)
-            self.subqueue.release()
-
-
 class DownloadQueue(DownloadQueueBase):
     """Class for queued download of models."""
 
@@ -145,7 +103,7 @@ class DownloadQueue(DownloadQueueBase):
         filename: Optional[Path] = None,
         variant: Optional[str] = None,
         access_token: Optional[str] = None,
-        event_handlers: Optional[List[DownloadEventHandler]] = None,
+        event_handlers: List[DownloadEventHandler] = [],
     ) -> DownloadJobBase:
         """Create a download job and return its ID."""
         kwargs: Dict[str, Optional[str]] = dict()
@@ -153,12 +111,6 @@ class DownloadQueue(DownloadQueueBase):
         cls = DownloadJobBase
         if Path(source).exists():
             cls = DownloadJobPath
-        elif re.match(REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE, str(source)):
-            cls = DownloadJobRepoID
-            kwargs.update(
-                variant=variant,
-                access_token=access_token,
-            )
         elif re.match(HTTP_RE, str(source)):
             cls = DownloadJobURL
             kwargs.update(access_token=access_token)
@@ -168,7 +120,7 @@ class DownloadQueue(DownloadQueueBase):
         job = cls(
             source=source,
             destination=Path(destdir) / (filename or "."),
-            event_handlers=(event_handlers or self._event_handlers),
+            event_handlers=event_handlers,
             priority=priority,
             **kwargs,
         )
@@ -349,89 +301,32 @@ class DownloadQueue(DownloadQueueBase):
             if job.status == DownloadJobStatus.ENQUEUED:  # Don't do anything for non-enqueued jobs (shouldn't happen)
                 if not self._quiet:
                     self._logger.info(f"{job.source}: Downloading to {job.destination}")
-                if isinstance(job, DownloadJobURL):
-                    self._download_with_resume(job)
-                elif isinstance(job, DownloadJobRepoID):
-                    self._download_repoid(job)
-                elif isinstance(job, DownloadJobPath):
-                    self._download_path(job)
-                else:
-                    raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
+                do_download = self.select_downloader(job)
+                do_download(job)
 
             if job.status == DownloadJobStatus.CANCELLED:
                 self._cleanup_cancelled_job(job)
 
             self._queue.task_done()
 
-    def _get_metadata_and_url(self, job: DownloadJobRemoteSource) -> AnyHttpUrl:
-        """
-        Fetch metadata from certain well-known URLs.
+    def select_downloader(self, job: DownloadJobBase) -> Callable[[DownloadJobBase], None]:
+        """Based on the job type select the download method."""
+        if isinstance(job, DownloadJobURL):
+            return self._download_with_resume
+        elif isinstance(job, DownloadJobPath):
+            return self._download_path
+        else:
+            raise NotImplementedError(f"Don't know what to do with this job: {job}, type={type(job)}")
 
-        The metadata will be stashed in job.metadata, if found
-        Return the download URL.
-        """
-        metadata = job.metadata
-        url = job.source
-        metadata_url = url
-        model = None
-
-        # a Civitai download URL
-        if match := re.match(CIVITAI_MODEL_DOWNLOAD, str(metadata_url)):
-            version = match.group(1)
-            resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
-            metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
-            metadata.description = metadata.description or (
-                f"Trigger terms: {(', ').join(resp['trainedWords'])}" if resp["trainedWords"] else resp["description"]
-            )
-            metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
-
-        # a Civitai model page with the version
-        if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, str(metadata_url)):
-            model = match.group(1)
-            version = int(match.group(2))
-        # and without
-        elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", str(metadata_url)):
-            model = match.group(1)
-            version = None
-
-        if not model:
-            return parse_obj_as(AnyHttpUrl, url)
-
-        if model:
-            resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
-
-            metadata.author = metadata.author or resp["creator"]["username"]
-            metadata.tags = metadata.tags or resp["tags"]
-            metadata.license = (
-                metadata.license
-                or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
-            )
-
-            if version:
-                versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
-                version_data = versions[0]
-            else:
-                version_data = resp["modelVersions"][0]  # first one
-
-            metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
-            metadata.description = metadata.description or (
-                f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
-                if version_data.get("trainedWords")
-                else version_data.get("description")
-            )
-
-            download_url = version_data["downloadUrl"]
-
-        # return the download url
-        return download_url
+    def get_url_for_job(self, job: DownloadJobBase) -> AnyHttpUrl:
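+        """Return the URL to download; in this base implementation, the job source itself."""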
+        return job.source
 
     def _download_with_resume(self, job: DownloadJobBase):
         """Do the actual download."""
         dest = None
         try:
             assert isinstance(job, DownloadJobRemoteSource)
-            url = self._get_metadata_and_url(job)
-
+            url = self.get_url_for_job(job)
             header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
             open_mode = "wb"
             exist_size = 0
@@ -497,7 +392,8 @@ class DownloadQueue(DownloadQueueBase):
                     if job.bytes - last_report_bytes >= report_delta:
                         last_report_bytes = job.bytes
                         self._update_job_status(job)
-
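+            # The transfer loop also exits when a job is cancelled, paused or
+            # errored mid-stream; only a still-RUNNING job may be marked COMPLETED.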
+            if job.status != DownloadJobStatus.RUNNING:  # cancelled, paused or errored
+                return
             self._update_job_status(job, DownloadJobStatus.COMPLETED)
         except KeyboardInterrupt as excp:
             raise excp
@@ -541,166 +437,6 @@ class DownloadQueue(DownloadQueueBase):
                 job.error = excp
                 self._update_job_status(job, DownloadJobStatus.ERROR)
 
-    def _download_repoid(self, job: DownloadJobRepoID):
-        """Download a job that holds a huggingface repoid."""
-
-        def subdownload_event(subjob: DownloadJobBase):
-            assert isinstance(subjob, DownloadJobRemoteSource)
-            if subjob.status == DownloadJobStatus.RUNNING:
-                bytes_downloaded[subjob.id] = subjob.bytes
-                job.bytes = sum(bytes_downloaded.values())
-                self._update_job_status(job, DownloadJobStatus.RUNNING)
-                return
-
-            if subjob.status == DownloadJobStatus.ERROR:
-                job.error = subjob.error
-                job.cleanup()
-                self._update_job_status(job, DownloadJobStatus.ERROR)
-                return
-
-            if subjob.status == DownloadJobStatus.COMPLETED:
-                bytes_downloaded[subjob.id] = subjob.bytes
-                job.bytes = sum(bytes_downloaded.values())
-                self._update_job_status(job, DownloadJobStatus.RUNNING)
-                return
-
-        subqueue = self.__class__(
-            event_handlers=[subdownload_event],
-            requests_session=self._requests,
-            quiet=True,
-        )
-        try:
-            assert isinstance(job, DownloadJobRepoID)
-            repo_id = job.source
-            variant = job.variant
-            if not job.metadata:
-                job.metadata = ModelSourceMetadata()
-            urls_to_download = self._get_repo_info(
-                repo_id, variant=variant, metadata=job.metadata, subfolder=job.subfolder
-            )
-            if job.destination.name != Path(repo_id).name:
-                job.destination = job.destination / Path(repo_id).name
-            bytes_downloaded: Dict[int, int] = dict()
-            job.total_bytes = 0
-
-            for url, subdir, file, size in urls_to_download:
-                job.total_bytes += size
-                subqueue.create_download_job(
-                    source=url,
-                    destdir=job.destination / subdir,
-                    filename=file,
-                    variant=variant,
-                    access_token=job.access_token,
-                )
-        except KeyboardInterrupt as excp:
-            raise excp
-        except Exception as excp:
-            job.error = excp
-            self._update_job_status(job, DownloadJobStatus.ERROR)
-            self._logger.error(job.error)
-        finally:
-            job.subqueue = subqueue
-            job.subqueue.join()
-            if job.status == DownloadJobStatus.RUNNING:
-                self._update_job_status(job, DownloadJobStatus.COMPLETED)
-            job.subqueue.release()  # get rid of the subqueue
-
-    def _get_repo_info(
-        self,
-        repo_id: str,
-        metadata: ModelSourceMetadata,
-        variant: Optional[str] = None,
-        subfolder: Optional[str] = None,
-    ) -> List[Tuple[AnyHttpUrl, Path, Path, int]]:
-        """
-        Given a repo_id and an optional variant, return list of URLs to download to get the model.
-        The metadata field will be updated with model metadata from HuggingFace.
-
-        Known variants currently are:
-        1. onnx
-        2. openvino
-        3. fp16
-        4. None (usually returns fp32 model)
-        """
-        model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
-        sibs = model_info.siblings
-        paths = [x.rfilename for x in sibs]
-        sizes = {x.rfilename: x.size for x in sibs}
-
-        prefix = ""
-        if subfolder:
-            prefix = f"{subfolder}/"
-            paths = [x for x in paths if x.startswith(prefix)]
-
-        if f"{prefix}model_index.json" in paths:
-            url = hf_hub_url(repo_id, filename="model_index.json", subfolder=subfolder)
-            resp = self._requests.get(url)
-            resp.raise_for_status()  # will raise an HTTPError on non-200 status
-            submodels = resp.json()
-            paths = [Path(subfolder or "", x) for x in paths if Path(x).parent.as_posix() in submodels]
-            paths.insert(0, f"{prefix}model_index.json")
-        urls = [
-            (
-                hf_hub_url(repo_id, filename=x.as_posix()),
-                x.parent.relative_to(prefix) or Path("."),
-                Path(x.name),
-                sizes[x.as_posix()],
-            )
-            for x in self._select_variants(paths, variant)
-        ]
-        if hasattr(model_info, "cardData"):
-            metadata.license = metadata.license or model_info.cardData.get("license")
-        metadata.tags = metadata.tags or model_info.tags
-        metadata.author = metadata.author or model_info.author
-        return urls
-
-    def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
-        """Select the proper variant files from a list of HuggingFace repo_id paths."""
-        result = set()
-        basenames: Dict[Path, Path] = dict()
-        for p in paths:
-            path = Path(p)
-
-            if path.suffix == ".onnx":
-                if variant == "onnx":
-                    result.add(path)
-
-            elif path.name.startswith("openvino_model"):
-                if variant == "openvino":
-                    result.add(path)
-
-            elif path.suffix in [".json", ".txt"]:
-                result.add(path)
-
-            elif path.suffix in [".bin", ".safetensors", ".pt"] and variant in ["fp16", None]:
-                parent = path.parent
-                suffixes = path.suffixes
-                if len(suffixes) == 2:
-                    file_variant, suffix = suffixes
-                    basename = parent / Path(path.stem).stem
-                else:
-                    file_variant = None
-                    suffix = suffixes[0]
-                    basename = parent / path.stem
-
-                if previous := basenames.get(basename):
-                    if previous.suffix != ".safetensors" and suffix == ".safetensors":
-                        basenames[basename] = path
-                    if file_variant == f".{variant}":
-                        basenames[basename] = path
-                    elif not variant and not file_variant:
-                        basenames[basename] = path
-                else:
-                    basenames[basename] = path
-
-            else:
-                continue
-
-        for v in basenames.values():
-            result.add(v)
-
-        return result
-
     def _download_path(self, job: DownloadJobBase):
         """Call when the source is a Path or pathlike object."""
         source = Path(job.source).resolve()
diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py
index df1f4e401a..e8ae798802 100644
--- a/invokeai/backend/model_manager/install.py
+++ b/invokeai/backend/model_manager/install.py
@@ -72,14 +72,20 @@ from .config import (
     SchedulerPredictionType,
     SubModelType,
 )
-from .download import DownloadEventHandler, DownloadJobBase, DownloadQueue, DownloadQueueBase, ModelSourceMetadata
-from .download.queue import (
+from .download import (
+    DownloadEventHandler,
+    DownloadJobBase,
+    DownloadJobPath,
+    DownloadJobURL,
+    DownloadQueueBase,
+    ModelDownloadQueue,
+    ModelSourceMetadata,
+)
+from .download.model_queue import (
     HTTP_RE,
     REPO_ID_WITH_OPTIONAL_SUBFOLDER_RE,
-    DownloadJobRemoteSource,
-    DownloadJobPath,
     DownloadJobRepoID,
-    DownloadJobURL,
+    DownloadJobWithMetadata,
 )
 from .hash import FastModelHash
 from .models import InvalidModelException
@@ -88,7 +94,7 @@ from .search import ModelSearch
 from .storage import DuplicateModelException, ModelConfigStore
 
 
-class ModelInstallJob(DownloadJobRemoteSource):
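+# ModelInstallJob is now source-agnostic; URL-based installs regain the
+# remote-source and metadata fields through the mix-ins on ModelInstallURLJob.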
+class ModelInstallJob(DownloadJobBase):
     """This is a version of DownloadJobBase that has an additional slot for the model key and probe info."""
 
     model_key: Optional[str] = Field(
@@ -100,7 +106,7 @@ class ModelInstallJob(DownloadJobRemoteSource):
     )
 
 
-class ModelInstallURLJob(DownloadJobURL, ModelInstallJob):
+class ModelInstallURLJob(DownloadJobWithMetadata, ModelInstallJob):
     """Job for installing URLs."""
 
 
@@ -398,7 +404,7 @@ class ModelInstall(ModelInstallBase):
         self._app_config = config or InvokeAIAppConfig.get_config()
         self._logger = logger or InvokeAILogger.get_logger(config=self._app_config)
         self._store = store or ModelRecordServiceBase.get_impl(self._app_config)
-        self._download_queue = download or DownloadQueue(config=self._app_config, event_handlers=event_handlers)
+        self._download_queue = download or ModelDownloadQueue(config=self._app_config, event_handlers=event_handlers)
         self._async_installs: Dict[Union[str, Path, AnyHttpUrl], Union[str, None]] = dict()
         self._installed = set()
         self._tmpdir = None
diff --git a/tests/AC_model_manager/test_model_download.py b/tests/AC_model_manager/test_model_download.py
index 835f0a8629..3ef7d0c930 100644
--- a/tests/AC_model_manager/test_model_download.py
+++ b/tests/AC_model_manager/test_model_download.py
@@ -8,11 +8,11 @@ import requests
 from requests import HTTPError
 from requests_testadapter import TestAdapter
 
-import invokeai.backend.model_manager.download.queue as download_queue
+import invokeai.backend.model_manager.download.model_queue as download_queue
 from invokeai.backend.model_manager.download import (
     DownloadJobBase,
     DownloadJobStatus,
-    DownloadQueue,
+    ModelDownloadQueue,
     UnknownJobIDException,
 )
 
@@ -147,7 +147,7 @@ def test_basic_queue_download():
     def event_handler(job: DownloadJobBase):
         events.append(job.status)
 
-    queue = DownloadQueue(
+    queue = ModelDownloadQueue(
         requests_session=session,
         event_handlers=[event_handler],
     )
@@ -167,7 +167,7 @@
 
 
 def test_queue_priority():
-    queue = DownloadQueue(
+    queue = ModelDownloadQueue(
         requests_session=session,
     )
 
@@ -200,7 +200,7 @@ def test_repo_id_download():
     if not INTERNET_AVAILABLE:
         return
     repo_id = "stabilityai/stable-diffusion-2-1"
-    queue = DownloadQueue(
+    queue = ModelDownloadQueue(
        requests_session=session,
     )
 
@@ -224,7 +224,7 @@
 
 
 def test_bad_urls():
-    queue = DownloadQueue(
+    queue = ModelDownloadQueue(
         requests_session=session,
     )
 
@@ -270,9 +270,11 @@
 def test_pause_cancel_url():  # this one is tricky because of potential race conditions
     def event_handler(job: DownloadJobBase):
-        time.sleep(0.5)  # slow down the thread by blocking it just a bit at every step
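+        # Throttle only the first job; sleeping on every event would stall
+        # the whole queue instead of just the job whose pause is under test.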
+        if job.id == 0:
+            print(job.status, job.bytes)
+            time.sleep(0.5)  # slow down the thread so that we can recover the paused state
 
-    queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
+    queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
 
     with tempfile.TemporaryDirectory() as tmpdir:
         job1 = queue.create_download_job(source="http://www.civitai.com/models/12345", destdir=tmpdir, start=False)
         job2 = queue.create_download_job(source="http://www.civitai.com/models/9999", destdir=tmpdir, start=False)
@@ -322,7 +324,7 @@ def test_pause_cancel_url():  # this one is tricky because of potential race con
         return
 
     repo_id = "stabilityai/stable-diffusion-2-1"
-    queue = DownloadQueue(requests_session=session, event_handlers=[event_handler])
+    queue = ModelDownloadQueue(requests_session=session, event_handlers=[event_handler])
 
     with tempfile.TemporaryDirectory() as tmpdir1, tempfile.TemporaryDirectory() as tmpdir2:
         job1 = queue.create_download_job(source=repo_id, destdir=tmpdir1, variant="fp16", start=False)
diff --git a/tests/AC_model_manager/test_model_install_service.py b/tests/AC_model_manager/test_model_install_service.py
index 3087b3d7df..aa8797941c 100644
--- a/tests/AC_model_manager/test_model_install_service.py
+++ b/tests/AC_model_manager/test_model_install_service.py
@@ -55,7 +55,8 @@ def test_install(datadir: Path):
     mm_install = ModelInstallService(config=config, store=mm_store, event_bus=event_bus)
 
     source = datadir / TEST_MODEL
-    mm_install.install_model(source=source)
+    job = mm_install.install_model(source=source)
+    print(f"DEBUG: job={type(job)}")
     id_map = mm_install.wait_for_installs()
     print(id_map)
     assert source in id_map, "model did not install; id_map empty"