incorporate civitai metadata into model config

Lincoln Stein
2023-09-09 21:17:55 -04:00
parent 3582cfa267
commit b2892f9068
9 changed files with 174 additions and 65 deletions

View File

@@ -107,8 +107,8 @@ class ModelConfigBase(BaseModel):
description: Optional[str] = Field(None)
author: Optional[str] = Field(description="Model author")
license: Optional[str] = Field(description="License string")
source: Optional[str] = Field(description="Model download source (URL or repo_id)")
thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
source_url: Optional[str] = Field(description="Model download source")
tags: Optional[List[str]] = Field(description="Descriptive tags") # Set would be better, but not JSON serializable
class Config:
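For context, a minimal sketch of how the revised fields might be filled in on an existing record; config is assumed to be a ModelConfigBase instance, and every value below is illustrative:

config.author = "some-creator"                                   # hypothetical
config.license = "CreativeML Open RAIL-M"                        # hypothetical
config.source = "https://civitai.com/api/download/models/12345"  # hypothetical version id
config.thumbnail_url = "https://example.com/preview.png"
config.tags = ["style", "photorealistic"]

This mirrors what the install callback does later in this commit when it copies ModelSourceMetadata onto the stored config.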

View File

@@ -6,6 +6,7 @@ from .base import ( # noqa F401
DownloadEventHandler,
UnknownJobIDException,
DownloadJobBase,
ModelSourceMetadata,
)
from .queue import DownloadQueue # noqa F401

View File

@@ -9,6 +9,7 @@ from functools import total_ordering
from pathlib import Path
from typing import List, Optional, Callable, Dict, Any
from pydantic import BaseModel, Field
from pydantic.networks import AnyHttpUrl
class DownloadJobStatus(str, Enum):
@@ -27,6 +28,16 @@ class UnknownJobIDException(Exception):
"""Raised when an invalid Job is referenced."""
class ModelSourceMetadata(BaseModel):
"""Information collected on a downloadable model from its source site."""
author: Optional[str] = Field(description="Author/creator of the model")
description: Optional[str] = Field(description="Description of the model")
license: Optional[str] = Field(description="Model license terms")
thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
tags: Optional[List[str]] = Field(description="List of descriptive tags")
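A minimal construction sketch; all fields are optional, the values are illustrative, and pydantic coerces the thumbnail string to AnyHttpUrl:

meta = ModelSourceMetadata(
    author="some-author",
    description="Description scraped from the source page",
    license="CreativeML Open RAIL-M",
    thumbnail_url="https://example.com/preview.png",
    tags=["anime", "character"],
)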
DownloadEventHandler = Callable[["DownloadJobBase"], None]
@@ -38,21 +49,30 @@ class DownloadJobBase(BaseModel):
id: int = Field(description="Numeric ID of this job")
source: str = Field(description="URL or repo_id to download")
destination: Path = Field(description="Destination of the downloaded file on local disk")
metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
access_token: Optional[str] = Field(description="Access token needed to access this resource")
status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
bytes: int = Field(default=0, description="Bytes downloaded so far")
total_bytes: int = Field(default=0, description="Total bytes to download")
event_handlers: Optional[List[DownloadEventHandler]] = Field(
description="Callables that will be called whenever job status changes"
description="Callables that will be called whenever job status changes",
default_factory=list,
)
job_started: Optional[float] = Field(description="Timestamp for when the download job started")
job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
job_sequence: Optional[int] = Field(
description="Counter that records order in which this job was dequeued (for debugging)"
)
metadata: Dict[str, Any] = Field(default_factory=dict, description="Model metadata (source-specific)")
error: Optional[Exception] = Field(default=None, description="Exception that caused an error")
def add_event_handler(self, handler: DownloadEventHandler):
"""Add an event handler to the end of the handlers list."""
self.event_handlers.append(handler)
def clear_event_handlers(self):
"""Clear all event handlers."""
self.event_handlers = list()
class Config:
"""Config object for this pydantic class."""

View File

@@ -7,7 +7,9 @@ import requests
import shutil
import threading
import time
import traceback
from json import JSONDecodeError
from pathlib import Path
from requests import HTTPError
from typing import Dict, Optional, Set, List, Tuple
@@ -19,17 +21,26 @@ from queue import PriorityQueue
from huggingface_hub import HfApi, hf_hub_url
from invokeai.backend.util.logging import InvokeAILogger
from invokeai.app.services.config import InvokeAIAppConfig
from .base import (
DownloadQueueBase,
DownloadJobStatus,
DownloadEventHandler,
UnknownJobIDException,
DownloadJobBase,
ModelSourceMetadata,
)
from ..storage import DuplicateModelException
# marker that the queue is done and that the thread should exit
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
# endpoints for the Civitai model metadata APIs
CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
class DownloadJobURL(DownloadJobBase):
"""Job declaration for downloading individual URLs."""
@@ -68,6 +79,7 @@ class DownloadQueue(DownloadQueueBase):
max_parallel_dl: int = 5,
event_handlers: Optional[List[DownloadEventHandler]] = None,
requests_session: Optional[requests.sessions.Session] = None,
config: Optional[InvokeAIAppConfig] = None,
):
"""
Initialize DownloadQueue.
@@ -81,7 +93,7 @@ class DownloadQueue(DownloadQueueBase):
self._queue = PriorityQueue()
self._worker_pool = set()
self._lock = threading.RLock()
self._logger = InvokeAILogger.getLogger()
self._logger = InvokeAILogger.getLogger(config=config)
self._event_handlers = event_handlers
self._requests = requests_session or requests.Session()
@@ -269,43 +281,91 @@ class DownloadQueue(DownloadQueueBase):
if self._in_terminal_state(job):
del self._jobs[job.id]
if job.status == "error":
self._logger.warning(f"{job.source}: Download finished with error: {job.error}")
else:
self._logger.info(f"{job.source}: Download finished with status {job.status}")
self._queue.task_done()
def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]:
"""
Fetch metadata from certain well-known URLs.
The metadata will be stashed in job.metadata, if found.
Return the download URL.
"""
metadata = ModelSourceMetadata()
url = job.source
metadata_url = url
try:
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = resp["images"][0]["url"]
metadata.description = (
f"Trigger terms: {(', ').join(resp['trainedWords'])}"
if resp["trainedWords"]
else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
# a Civitai model page
if match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
model = match.group(1)
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
# note that we munge the URL here to get the download URL of the first model version
url = resp["modelVersions"][0]["downloadUrl"]
metadata.author = resp["creator"]["username"]
metadata.tags = resp["tags"]
metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"]
metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warning(excp)
# return the resolved download URL and the collected metadata
return url, metadata
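For reference, a hedged sketch of the response shapes the code above relies on; the key names come straight from the lookups in _fetch_metadata, while every value is illustrative:

# GET https://civitai.com/api/v1/model-versions/<version_id>
version_resp = {
    "modelId": 678,                                 # used to build the model page URL
    "trainedWords": ["myTriggerWord"],              # becomes "Trigger terms: ..."
    "description": "Fallback description",
    "images": [{"url": "https://example.com/preview.png"}],
}
# GET https://civitai.com/api/v1/models/<model_id>
model_resp = {
    "creator": {"username": "some-author"},
    "tags": ["anime"],
    "allowCommercialUse": "Sell",                   # folded into the license string
    "allowDerivatives": True,
    "allowNoCredit": True,
    "modelVersions": [
        {"downloadUrl": "https://civitai.com/api/download/models/12345",
         "images": [{"url": "https://example.com/preview.png"}]},
    ],
}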
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
exist_size = 0
resp = self._requests.get(job.source, headers=header, stream=True)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.destination.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
self._validate_filename(
job.destination, file_name
) # will raise a ValueError exception if file_name is suspicious
except ValueError:
self._logger.warning(
f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead"
)
file_name = os.path.basename(job.source)
except KeyError:
file_name = os.path.basename(job.source)
job.destination = job.destination / file_name
dest = job.destination
else:
dest = job.destination
dest.parent.mkdir(parents=True, exist_ok=True)
try:
url, metadata = self._fetch_metadata(job)
job.metadata = metadata
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
exist_size = 0
resp = self._requests.get(url, headers=header, stream=True)
content_length = int(resp.headers.get("content-length", 0))
job.total_bytes = content_length
if job.destination.is_dir():
try:
file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
self._validate_filename(
job.destination, file_name
) # will raise a ValueError exception if file_name is suspicious
except ValueError:
self._logger.warning(
f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
)
file_name = os.path.basename(url)
except KeyError:
file_name = os.path.basename(url)
job.destination = job.destination / file_name
dest = job.destination
else:
dest = job.destination
dest.parent.mkdir(parents=True, exist_ok=True)
if dest.exists():
exist_size = dest.stat().st_size
job.bytes = exist_size
header["Range"] = f"bytes={job.bytes}-"
open_mode = "ab"
resp = self._requests.get(job.source, headers=header, stream=True) # new request with range
resp = self._requests.get(url, headers=header, stream=True) # new request with range
if exist_size > content_length:
self._logger.warning("corrupt existing file found. re-downloading")
@@ -318,11 +378,11 @@ class DownloadQueue(DownloadQueueBase):
return
if resp.status_code == 206 or exist_size > 0:
self._logger.warning(f"{dest}: partial file found. Resuming...")
self._logger.warning(f"{dest}: partial file found. Resuming")
elif resp.status_code != 200:
raise HTTPError(resp.reason)
else:
self._logger.info(f"{dest}: Downloading...")
self._logger.info(f"{job.source}: Downloading {job.destination}")
report_delta = job.total_bytes / 100 # report every 1% change
last_report_bytes = 0
@@ -338,8 +398,13 @@ class DownloadQueue(DownloadQueueBase):
self._update_job_status(job)
self._update_job_status(job, DownloadJobStatus.COMPLETED)
except DuplicateModelException as excp:
self._logger.error(f"A model with the same hash as {dest} is already installed.")
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
except Exception as excp:
self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}")
self._logger.error(f"An error occurred while downloading/installing {dest}: {str(excp)}")
traceback.print_exception(excp)
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
@@ -363,8 +428,12 @@ class DownloadQueue(DownloadQueueBase):
elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
job.job_ended = time.time()
if job.event_handlers:
for handler in job.event_handlers:
handler(job)
try:
for handler in job.event_handlers:
handler(job)
except Exception as excp:
job.status = DownloadJobStatus.ERROR
job.error = excp
def _download_repoid(self, job: DownloadJobBase):
"""Download a job that holds a huggingface repoid."""
@@ -410,8 +479,8 @@ class DownloadQueue(DownloadQueueBase):
access_token=job.access_token,
)
except Exception as excp:
job.status = DownloadJobStatus.ERROR
job.error = excp
self._update_job_status(job, DownloadJobStatus.ERROR)
self._logger.error(job.error)
finally:
subqueue.join()
@@ -423,7 +492,7 @@ class DownloadQueue(DownloadQueueBase):
self,
repo_id: str,
variant: Optional[str] = None,
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
"""Given a repo_id and an optional variant, return list of URLs to download to get the model."""
model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
sibs = model_info.siblings
@@ -440,7 +509,12 @@ class DownloadQueue(DownloadQueueBase):
(hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
for x in self._select_variants(paths, variant)
]
return (urls, {"cardData": model_info.cardData, "tags": model_info.tags, "author": model_info.author})
return (
urls,
ModelSourceMetadata(
license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author
),
)
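As a pointer, hf_hub_url simply composes the canonical resolve URL for a repo file; a minimal sketch with a placeholder repo and filename:

from huggingface_hub import hf_hub_url

url = hf_hub_url("some-org/some-model", filename="unet/diffusion_pytorch_model.safetensors")
# -> "https://huggingface.co/some-org/some-model/resolve/main/unet/diffusion_pytorch_model.safetensors"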
def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
"""Select the proper variant files from a list of HuggingFace repo_id paths."""

View File

@@ -57,7 +57,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .search import ModelSearch
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
from .config import (
@@ -250,9 +250,9 @@ class ModelInstall(ModelInstallBase):
download: Optional[DownloadQueueBase] = None,
): # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger()
self._logger = logger or InvokeAILogger.getLogger(config=self._config)
self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
self._download_queue = download or DownloadQueue()
self._download_queue = download or DownloadQueue(config=self._config)
self._async_installs = dict()
self._tmpdir = None
@@ -274,6 +274,14 @@ class ModelInstall(ModelInstallBase):
if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
try:
config_file = self._legacy_configs[info.base_type][info.variant_type]
if isinstance(config_file, dict): # need another tier for sd-2.x models
if prediction_type := info.prediction_type:
config_file = config_file[prediction_type]
else:
self._logger.warning(
f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
)
config_file = config_file[SchedulerPredictionType.VPrediction]
except KeyError as exc:
raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
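To illustrate the extra tier, a hypothetical sketch of what self._legacy_configs might hold for SD-2 checkpoints; the enum keys and YAML filenames are assumptions, not taken from this diff:

self._legacy_configs = {
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
            # second tier: prediction type -> legacy config file
            SchedulerPredictionType.Epsilon: "v2-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
        },
    },
}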
registration_data.update(
@@ -314,20 +322,18 @@ class ModelInstall(ModelInstallBase):
self._async_installs[source] = None
def complete_installation(job: DownloadJobBase):
self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})")
if job.status == "completed":
id = self.install(job.destination)
info = self._store.get_model(id)
info.description = f"Downloaded model {info.name}"
info.source_url = str(job.source)
if card_data := job.metadata.get("cardData"):
info.license = card_data.get("license")
if author := job.metadata.get("author"):
info.author = author
if tags := job.metadata.get("tags"):
info.tags = tags
self._store.update_model(id, info)
self._async_installs[job.source] = id
model_id = self.install(job.destination)
info = self._store.get_model(model_id)
info.source = str(job.source)
metadata: ModelSourceMetadata = job.metadata
info.description = metadata.description or f"Downloaded model {info.name}"
info.author = metadata.author
info.tags = metadata.tags
info.license = metadata.license
info.thumbnail_url = metadata.thumbnail_url
self._store.update_model(model_id, info)
self._async_installs[job.source] = model_id
jobs = queue.list_jobs()
if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir = None
@@ -336,9 +342,8 @@ class ModelInstall(ModelInstallBase):
# will be deleted before the job actually runs.
# Better to do the cleanup in the callback
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
return queue.create_download_job(
source=source, destdir=self._tmpdir.name, event_handlers=[complete_installation]
)
job = queue.create_download_job(source=source, destdir=self._tmpdir.name)
job.add_event_handler(complete_installation)
return job
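A hedged sketch of the same per-job wiring from a caller's side; the queue construction and source URL are illustrative:

queue = DownloadQueue(config=InvokeAIAppConfig.get_config())
job = queue.create_download_job(source="https://civitai.com/api/download/models/12345", destdir="/tmp/models")
job.add_event_handler(lambda j: print(j.status))
queue.join()  # block until outstanding jobs reach a terminal state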
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102
self._download_queue.join()

View File

@@ -206,7 +206,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
except sqlite3.IntegrityError as e:
self._conn.rollback()
if "UNIQUE constraint failed" in str(e):
raise DuplicateModelException from e
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
else:
raise e
except sqlite3.Error as e:

View File

@@ -116,7 +116,10 @@ class ModelConfigStoreYAML(ModelConfigStore):
try:
self._lock.acquire()
if key in self._config:
raise DuplicateModelException(f"Duplicate key {key} for model named '{record.name}'")
existing_model = self.get_model(key)
raise DuplicateModelException(
f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
)
self._config[key] = dict_fields
self._commit()
finally:

View File

@@ -181,6 +181,7 @@ import urllib.parse
from abc import abstractmethod
from pathlib import Path
from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
@@ -352,9 +353,8 @@ class InvokeAILogger(object):
loggers = dict()
@classmethod
def getLogger(
cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
) -> logging.Logger:
def getLogger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger:
config = config or InvokeAIAppConfig.get_config()
if name in cls.loggers:
logger = cls.loggers[name]
logger.handlers.clear()
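A minimal sketch of the revised call pattern; both calls below resolve to the same cached logger:

logger = InvokeAILogger.getLogger()  # falls back to the global app config
logger = InvokeAILogger.getLogger("InvokeAI", config=InvokeAIAppConfig.get_config())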

View File

@@ -54,8 +54,14 @@ session.mount(
},
),
)
# not found
# not found test
session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))
# prevent the tests from contacting Civitai for metadata
session.mount("https://civitai.com/api/download/models/", TestAdapter(b"Not found", status=404))
session.mount("https://civitai.com/api/v1/models/", TestAdapter(b"Not found", status=404))
session.mount("https://civitai.com/api/v1/model-versions/", TestAdapter(b"Not found", status=404))
# specifies a Content-Disposition that could overwrite files in the parent directory
session.mount(
"http://www.civitai.com/models/malicious",