diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py index 7ebcb70e07..1f4186732f 100644 --- a/invokeai/backend/model_manager/config.py +++ b/invokeai/backend/model_manager/config.py @@ -107,8 +107,8 @@ class ModelConfigBase(BaseModel): description: Optional[str] = Field(None) author: Optional[str] = Field(description="Model author") license: Optional[str] = Field(description="License string") + source: Optional[str] = Field(description="Model download source (URL or repo_id)") thumbnail_url: Optional[str] = Field(description="URL of thumbnail image") - source_url: Optional[str] = Field(description="Model download source") tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable class Config: diff --git a/invokeai/backend/model_manager/download/__init__.py b/invokeai/backend/model_manager/download/__init__.py index dec6d87e87..59bd617102 100644 --- a/invokeai/backend/model_manager/download/__init__.py +++ b/invokeai/backend/model_manager/download/__init__.py @@ -6,6 +6,7 @@ from .base import ( # noqa F401 DownloadEventHandler, UnknownJobIDException, DownloadJobBase, + ModelSourceMetadata, ) from .queue import DownloadQueue # noqa F401 diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py index 9684a0497c..c6bffe3cc2 100644 --- a/invokeai/backend/model_manager/download/base.py +++ b/invokeai/backend/model_manager/download/base.py @@ -9,6 +9,7 @@ from functools import total_ordering from pathlib import Path from typing import List, Optional, Callable, Dict, Any from pydantic import BaseModel, Field +from pydantic.networks import AnyHttpUrl class DownloadJobStatus(str, Enum): @@ -27,6 +28,16 @@ class UnknownJobIDException(Exception): """Raised when an invalid Job is referenced.""" +class ModelSourceMetadata(BaseModel): + """Information collected on a downloadable model from its source site.""" + + author: Optional[str] = Field(description="Author/creator of the model") + description: Optional[str] = Field(description="Description of the model") + license: Optional[str] = Field(description="Model license terms") + thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model") + tags: Optional[List[str]] = Field(description="List of descriptive tags") + + DownloadEventHandler = Callable[["DownloadJobBase"], None] @@ -38,21 +49,30 @@ class DownloadJobBase(BaseModel): id: int = Field(description="Numeric ID of this job") source: str = Field(description="URL or repo_id to download") destination: Path = Field(description="Destination of URL on local disk") + metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None) access_token: Optional[str] = Field(description="access token needed to access this resource") status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download") bytes: int = Field(default=0, description="Bytes downloaded so far") total_bytes: int = Field(default=0, description="Total bytes to download") event_handlers: Optional[List[DownloadEventHandler]] = Field( - description="Callables that will be called whenever job status changes" + description="Callables that will be called whenever job status changes", + default_factory=list, ) job_started: Optional[float] = Field(description="Timestamp for when the download job started") job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)") job_sequence: Optional[int] = Field( description="Counter that records order in which this job was dequeued (for debugging)" ) - metadata: Dict[str, Any] = Field(default_factory=dict, description="Model metadata (source-specific)") error: Optional[Exception] = Field(default=None, description="Exception that caused an error") + def add_event_handler(self, handler: DownloadEventHandler): + """Add an event handler to the end of the handlers list.""" + self.event_handlers.append(handler) + + def clear_event_handlers(self): + """Clear all event handlers.""" + self.event_handlers = list() + class Config: """Config object for this pydantic class.""" diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index 1bd01a8760..0767e27ea3 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -7,7 +7,9 @@ import requests import shutil import threading import time +import traceback +from json import JSONDecodeError from pathlib import Path from requests import HTTPError from typing import Dict, Optional, Set, List, Tuple @@ -19,17 +21,26 @@ from queue import PriorityQueue from huggingface_hub import HfApi, hf_hub_url from invokeai.backend.util.logging import InvokeAILogger +from invokeai.app.services.config import InvokeAIAppConfig from .base import ( DownloadQueueBase, DownloadJobStatus, DownloadEventHandler, UnknownJobIDException, DownloadJobBase, + ModelSourceMetadata, ) +from ..storage import DuplicateModelException # marker that the queue is done and that thread should exit STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/") +# endpoint for civitai get-model API +CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/" +CIVITAI_MODEL_PAGE = "https://civitai.com/models/" +CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/" +CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/" + class DownloadJobURL(DownloadJobBase): """Job declaration for downloading individual URLs.""" @@ -68,6 +79,7 @@ class DownloadQueue(DownloadQueueBase): max_parallel_dl: int = 5, event_handlers: Optional[List[DownloadEventHandler]] = None, requests_session: Optional[requests.sessions.Session] = None, + config: Optional[InvokeAIAppConfig] = None, ): """ Initialize DownloadQueue. @@ -81,7 +93,7 @@ class DownloadQueue(DownloadQueueBase): self._queue = PriorityQueue() self._worker_pool = set() self._lock = threading.RLock() - self._logger = InvokeAILogger.getLogger() + self._logger = InvokeAILogger.getLogger(config=config) self._event_handlers = event_handlers self._requests = requests_session or requests.Session() @@ -269,43 +281,91 @@ class DownloadQueue(DownloadQueueBase): if self._in_terminal_state(job): del self._jobs[job.id] + if job.status == "error": + self._logger.warning(f"{job.source}: Download finished with error: {job.error}") + else: + self._logger.info(f"{job.source}: Download finished with status {job.status}") self._queue.task_done() + def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]: + """ + Fetch metadata from certain well-known URLs. + + The metadata will be stashed in job.metadata, if found + Return the download URL. + """ + metadata = ModelSourceMetadata() + url = job.source + metadata_url = url + try: + # a Civitai download URL + if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url): + version = match.group(1) + resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json() + metadata.thumbnail_url = resp["images"][0]["url"] + metadata.description = ( + f"Trigger terms: {(', ').join(resp['trainedWords'])}" + if resp["trainedWords"] + else resp["description"] + ) + metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + + # a Civitai model page + if match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url): + model = match.group(1) + resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json() + + # note that we munge the URL here to get the download URL of the first model + url = resp["modelVersions"][0]["downloadUrl"] + + metadata.author = resp["creator"]["username"] + metadata.tags = resp["tags"] + metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"] + metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}" + except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp: + self._logger.warn(excp) + + # update metadata and return the download url + return url, metadata + def _download_with_resume(self, job: DownloadJobBase): """Do the actual download.""" - header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} - open_mode = "wb" - exist_size = 0 - - resp = self._requests.get(job.source, headers=header, stream=True) - content_length = int(resp.headers.get("content-length", 0)) - job.total_bytes = content_length - - if job.destination.is_dir(): - try: - file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1) - self._validate_filename( - job.destination, file_name - ) # will raise a ValueError exception if file_name is suspicious - except ValueError: - self._logger.warning( - f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead" - ) - file_name = os.path.basename(job.source) - except KeyError: - file_name = os.path.basename(job.source) - job.destination = job.destination / file_name - dest = job.destination - else: - dest = job.destination - dest.parent.mkdir(parents=True, exist_ok=True) - try: + url, metadata = self._fetch_metadata(job) + job.metadata = metadata + + header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} + open_mode = "wb" + exist_size = 0 + + resp = self._requests.get(url, headers=header, stream=True) + content_length = int(resp.headers.get("content-length", 0)) + job.total_bytes = content_length + + if job.destination.is_dir(): + try: + file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1) + self._validate_filename( + job.destination, file_name + ) # will raise a ValueError exception if file_name is suspicious + except ValueError: + self._logger.warning( + f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead" + ) + file_name = os.path.basename(url) + except KeyError: + file_name = os.path.basename(url) + job.destination = job.destination / file_name + dest = job.destination + else: + dest = job.destination + dest.parent.mkdir(parents=True, exist_ok=True) + if dest.exists(): job.bytes = dest.stat().st_size header["Range"] = f"bytes={job.bytes}-" open_mode = "ab" - resp = self._requests.get(job.source, headers=header, stream=True) # new request with range + resp = self._requests.get(url, headers=header, stream=True) # new request with range if exist_size > content_length: self._logger.warning("corrupt existing file found. re-downloading") @@ -318,11 +378,11 @@ class DownloadQueue(DownloadQueueBase): return if resp.status_code == 206 or exist_size > 0: - self._logger.warning(f"{dest}: partial file found. Resuming...") + self._logger.warning(f"{dest}: partial file found. Resuming") elif resp.status_code != 200: raise HTTPError(resp.reason) else: - self._logger.info(f"{dest}: Downloading...") + self._logger.info(f"{job.source}: Downloading {job.destination}") report_delta = job.total_bytes / 100 # report every 1% change last_report_bytes = 0 @@ -338,8 +398,13 @@ class DownloadQueue(DownloadQueueBase): self._update_job_status(job) self._update_job_status(job, DownloadJobStatus.COMPLETED) + except DuplicateModelException as excp: + self._logger.error(f"A model with the same hash as {dest} is already installed.") + job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) except Exception as excp: - self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}") + self._logger.error(f"An error occurred while downloading/installing {dest}: {str(excp)}") + traceback.print_exception(excp) job.error = excp self._update_job_status(job, DownloadJobStatus.ERROR) @@ -363,8 +428,12 @@ class DownloadQueue(DownloadQueueBase): elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]: job.job_ended = time.time() if job.event_handlers: - for handler in job.event_handlers: - handler(job) + try: + for handler in job.event_handlers: + handler(job) + except Exception as excp: + job.status = DownloadJobStatus.ERROR + job.error = excp def _download_repoid(self, job: DownloadJobBase): """Download a job that holds a huggingface repoid.""" @@ -410,8 +479,8 @@ class DownloadQueue(DownloadQueueBase): access_token=job.access_token, ) except Exception as excp: - job.status = DownloadJobStatus.ERROR job.error = excp + self._update_job_status(job, DownloadJobStatus.ERROR) self._logger.error(job.error) finally: subqueue.join() @@ -423,7 +492,7 @@ class DownloadQueue(DownloadQueueBase): self, repo_id: str, variant: Optional[str] = None, - ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]: + ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]: """Given a repo_id and an optional variant, return list of URLs to download to get the model.""" model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True) sibs = model_info.siblings @@ -440,7 +509,12 @@ class DownloadQueue(DownloadQueueBase): (hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()]) for x in self._select_variants(paths, variant) ] - return (urls, {"cardData": model_info.cardData, "tags": model_info.tags, "author": model_info.author}) + return ( + urls, + ModelSourceMetadata( + license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author + ), + ) def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]: """Select the proper variant files from a list of HuggingFace repo_id paths.""" diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index a7ac807c83..2494d2e812 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -57,7 +57,7 @@ from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger from .search import ModelSearch from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException -from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase +from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata from .hash import FastModelHash from .probe import ModelProbe, ModelProbeInfo, InvalidModelException from .config import ( @@ -250,9 +250,9 @@ class ModelInstall(ModelInstallBase): download: Optional[DownloadQueueBase] = None, ): # noqa D107 - use base class docstrings self._config = config or InvokeAIAppConfig.get_config() - self._logger = logger or InvokeAILogger.getLogger() + self._logger = logger or InvokeAILogger.getLogger(config=self._config) self._store = store or ModelConfigStoreYAML(self._config.model_conf_path) - self._download_queue = download or DownloadQueue() + self._download_queue = download or DownloadQueue(config=self._config) self._async_installs = dict() self._tmpdir = None @@ -274,6 +274,14 @@ class ModelInstall(ModelInstallBase): if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint: try: config_file = self._legacy_configs[info.base_type][info.variant_type] + if isinstance(config_file, dict): # need another tier for sd-2.x models + if prediction_type := info.prediction_type: + config_file = config_file[prediction_type] + else: + self._logger.warning( + f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model" + ) + config_file = config_file[SchedulerPredictionType.VPrediction] except KeyError as exc: raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc registration_data.update( @@ -314,20 +322,18 @@ class ModelInstall(ModelInstallBase): self._async_installs[source] = None def complete_installation(job: DownloadJobBase): - self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})") if job.status == "completed": - id = self.install(job.destination) - info = self._store.get_model(id) - info.description = f"Downloaded model {info.name}" - info.source_url = str(job.source) - if card_data := job.metadata.get("cardData"): - info.license = card_data.get("license") - if author := job.metadata.get("author"): - info.author = author - if tags := job.metadata.get("tags"): - info.tags = tags - self._store.update_model(id, info) - self._async_installs[job.source] = id + model_id = self.install(job.destination) + info = self._store.get_model(model_id) + info.source = str(job.source) + metadata: ModelSourceMetadata = job.metadata + info.description = metadata.description or f"Downloaded model {info.name}" + info.author = metadata.author + info.tags = metadata.tags + info.license = metadata.license + info.thumbnail_url = metadata.thumbnail_url + self._store.update_model(model_id, info) + self._async_installs[job.source] = model_id jobs = queue.list_jobs() if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]: self._tmpdir = None @@ -336,9 +342,8 @@ class ModelInstall(ModelInstallBase): # will be deleted before the job actually runs. # Better to do the cleanup in the callback self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) - return queue.create_download_job( - source=source, destdir=self._tmpdir.name, event_handlers=[complete_installation] - ) + job = queue.create_download_job(source=source, destdir=self._tmpdir.name) + job.add_event_handler(complete_installation) def wait_for_downloads(self) -> Dict[str, str]: # noqa D102 self._download_queue.join() diff --git a/invokeai/backend/model_manager/storage/sql.py b/invokeai/backend/model_manager/storage/sql.py index 4ef1652a7a..9f58e8286b 100644 --- a/invokeai/backend/model_manager/storage/sql.py +++ b/invokeai/backend/model_manager/storage/sql.py @@ -206,7 +206,7 @@ class ModelConfigStoreSQL(ModelConfigStore): except sqlite3.IntegrityError as e: self._conn.rollback() if "UNIQUE constraint failed" in str(e): - raise DuplicateModelException from e + raise DuplicateModelException(f"A model with key '{key}' is already installed") from e else: raise e except sqlite3.Error as e: diff --git a/invokeai/backend/model_manager/storage/yaml.py b/invokeai/backend/model_manager/storage/yaml.py index 728e405567..66098c0f5d 100644 --- a/invokeai/backend/model_manager/storage/yaml.py +++ b/invokeai/backend/model_manager/storage/yaml.py @@ -116,7 +116,10 @@ class ModelConfigStoreYAML(ModelConfigStore): try: self._lock.acquire() if key in self._config: - raise DuplicateModelException(f"Duplicate key {key} for model named '{record.name}'") + existing_model = self.get_model(key) + raise DuplicateModelException( + f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'" + ) self._config[key] = dict_fields self._commit() finally: diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 82706d8181..92e85b4c52 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -181,6 +181,7 @@ import urllib.parse from abc import abstractmethod from pathlib import Path +from typing import Optional from invokeai.app.services.config import InvokeAIAppConfig @@ -352,9 +353,8 @@ class InvokeAILogger(object): loggers = dict() @classmethod - def getLogger( - cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config() - ) -> logging.Logger: + def getLogger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger: + config = config or InvokeAIAppConfig.get_config() if name in cls.loggers: logger = cls.loggers[name] logger.handlers.clear() diff --git a/tests/test_model_download.py b/tests/test_model_download.py index b599501616..7ebbad6cc3 100644 --- a/tests/test_model_download.py +++ b/tests/test_model_download.py @@ -54,8 +54,14 @@ session.mount( }, ), ) -# not found +# not found test session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404)) + +# prevent us from going to civitai to get metadata +session.mount("https://civitai.com/api/download/models/", TestAdapter(b"Not found", status=404)) +session.mount("https://civitai.com/api/v1/models/", TestAdapter(b"Not found", status=404)) +session.mount("https://civitai.com/api/v1/model-versions/", TestAdapter(b"Not found", status=404)) + # specifies a content disposition that may overwrite files in the parent directory session.mount( "http://www.civitai.com/models/malicious",