mirror of https://github.com/invoke-ai/InvokeAI

incorporate civitai metadata into model config
@@ -107,8 +107,8 @@ class ModelConfigBase(BaseModel):
     description: Optional[str] = Field(None)
     author: Optional[str] = Field(description="Model author")
     license: Optional[str] = Field(description="License string")
     source: Optional[str] = Field(description="Model download source (URL or repo_id)")
-    source_url: Optional[str] = Field(description="Model download source")
+    thumbnail_url: Optional[str] = Field(description="URL of thumbnail image")
     tags: Optional[List[str]] = Field(description="Descriptive tags")  # Set would be better, but not JSON serializable

     class Config:
@@ -6,6 +6,7 @@ from .base import (  # noqa F401
     DownloadEventHandler,
     UnknownJobIDException,
     DownloadJobBase,
+    ModelSourceMetadata,
 )

 from .queue import DownloadQueue  # noqa F401
@@ -9,6 +9,7 @@ from functools import total_ordering
 from pathlib import Path
 from typing import List, Optional, Callable, Dict, Any
 from pydantic import BaseModel, Field
+from pydantic.networks import AnyHttpUrl


 class DownloadJobStatus(str, Enum):
@@ -27,6 +28,16 @@ class UnknownJobIDException(Exception):
     """Raised when an invalid Job is referenced."""


+class ModelSourceMetadata(BaseModel):
+    """Information collected on a downloadable model from its source site."""
+
+    author: Optional[str] = Field(description="Author/creator of the model")
+    description: Optional[str] = Field(description="Description of the model")
+    license: Optional[str] = Field(description="Model license terms")
+    thumbnail_url: Optional[AnyHttpUrl] = Field(description="URL of a thumbnail image for the model")
+    tags: Optional[List[str]] = Field(description="List of descriptive tags")
+
+
 DownloadEventHandler = Callable[["DownloadJobBase"], None]
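Annotation: since ModelSourceMetadata is an ordinary pydantic model, scraped metadata can be built and serialized independently of any download job. A minimal sketch, assuming pydantic v1 (which the `class Config` blocks elsewhere in this diff suggest); the class below is a pared-down stand-in, not the real one:

    from typing import List, Optional

    from pydantic import BaseModel, Field

    class Metadata(BaseModel):
        """Pared-down stand-in for ModelSourceMetadata, for illustration only."""

        author: Optional[str] = Field(None, description="Author/creator of the model")
        license: Optional[str] = Field(None, description="Model license terms")
        tags: Optional[List[str]] = Field(None, description="List of descriptive tags")

    meta = Metadata(author="some-creator", tags=["style", "anime"])
    print(meta.json(exclude_none=True))  # pydantic v1 API; v2 would use model_dump_json()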
@@ -38,21 +49,30 @@ class DownloadJobBase(BaseModel):
     id: int = Field(description="Numeric ID of this job")
     source: str = Field(description="URL or repo_id to download")
     destination: Path = Field(description="Destination of URL on local disk")
+    metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
     access_token: Optional[str] = Field(description="access token needed to access this resource")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
     bytes: int = Field(default=0, description="Bytes downloaded so far")
     total_bytes: int = Field(default=0, description="Total bytes to download")
     event_handlers: Optional[List[DownloadEventHandler]] = Field(
-        description="Callables that will be called whenever job status changes"
+        description="Callables that will be called whenever job status changes",
+        default_factory=list,
     )
     job_started: Optional[float] = Field(description="Timestamp for when the download job started")
     job_ended: Optional[float] = Field(description="Timestamp for when the download job ended (completed or errored)")
     job_sequence: Optional[int] = Field(
         description="Counter that records order in which this job was dequeued (for debugging)"
     )
-    metadata: Dict[str, Any] = Field(default_factory=dict, description="Model metadata (source-specific)")
     error: Optional[Exception] = Field(default=None, description="Exception that caused an error")

+    def add_event_handler(self, handler: DownloadEventHandler):
+        """Add an event handler to the end of the handlers list."""
+        self.event_handlers.append(handler)
+
+    def clear_event_handlers(self):
+        """Clear all event handlers."""
+        self.event_handlers = list()
+
     class Config:
         """Config object for this pydantic class."""
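Annotation: because event_handlers now defaults to an empty list via default_factory, callbacks can be attached after a job is created rather than only at construction time, which the installer changes further down rely on. A short sketch of that usage; the module path is hypothetical, since the diff does not show file names:

    from pathlib import Path

    from download.base import DownloadJobBase  # hypothetical import path

    def log_progress(job: DownloadJobBase) -> None:
        """Print rough percent-complete; fired on every status change."""
        if job.total_bytes:
            print(f"{job.source}: {100 * job.bytes / job.total_bytes:.0f}%")

    job = DownloadJobBase(
        id=1,
        source="https://example.com/model.safetensors",  # illustrative URL
        destination=Path("/tmp/model.safetensors"),
    )
    job.add_event_handler(log_progress)  # possible now that event_handlers defaults to []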
@@ -7,7 +7,9 @@ import requests
 import shutil
 import threading
 import time
+import traceback

+from json import JSONDecodeError
 from pathlib import Path
 from requests import HTTPError
 from typing import Dict, Optional, Set, List, Tuple
@@ -19,17 +21,26 @@ from queue import PriorityQueue
 from huggingface_hub import HfApi, hf_hub_url

 from invokeai.backend.util.logging import InvokeAILogger
+from invokeai.app.services.config import InvokeAIAppConfig
 from .base import (
     DownloadQueueBase,
     DownloadJobStatus,
     DownloadEventHandler,
     UnknownJobIDException,
     DownloadJobBase,
+    ModelSourceMetadata,
 )
 from ..storage import DuplicateModelException

 # marker that the queue is done and that thread should exit
 STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")

+# endpoint for civitai get-model API
+CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/"
+CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
+CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
+CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
+

 class DownloadJobURL(DownloadJobBase):
     """Job declaration for downloading individual URLs."""
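Annotation: these constants double as regex prefixes in the new _fetch_metadata method below; a Civitai download URL has its numeric version id peeled off with re.match. A standalone check of that pattern (the version id is made up):

    import re

    CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/"

    url = "https://civitai.com/api/download/models/12345"  # made-up version id
    if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", url):
        print(match.group(1))  # -> 12345, the key for the model-versions endpoint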
@@ -68,6 +79,7 @@ class DownloadQueue(DownloadQueueBase):
         max_parallel_dl: int = 5,
         event_handlers: Optional[List[DownloadEventHandler]] = None,
         requests_session: Optional[requests.sessions.Session] = None,
+        config: Optional[InvokeAIAppConfig] = None,
     ):
         """
         Initialize DownloadQueue.
@@ -81,7 +93,7 @@ class DownloadQueue(DownloadQueueBase):
         self._queue = PriorityQueue()
         self._worker_pool = set()
         self._lock = threading.RLock()
-        self._logger = InvokeAILogger.getLogger()
+        self._logger = InvokeAILogger.getLogger(config=config)
         self._event_handlers = event_handlers
         self._requests = requests_session or requests.Session()
@@ -269,43 +281,91 @@ class DownloadQueue(DownloadQueueBase):
             if self._in_terminal_state(job):
                 del self._jobs[job.id]

+                if job.status == "error":
+                    self._logger.warning(f"{job.source}: Download finished with error: {job.error}")
+                else:
+                    self._logger.info(f"{job.source}: Download finished with status {job.status}")
             self._queue.task_done()

+    def _fetch_metadata(self, job: DownloadJobBase) -> Tuple[AnyHttpUrl, ModelSourceMetadata]:
+        """
+        Fetch metadata from certain well-known URLs.
+
+        The metadata will be stashed in job.metadata, if found
+        Return the download URL.
+        """
+        metadata = ModelSourceMetadata()
+        url = job.source
+        metadata_url = url
+        try:
+            # a Civitai download URL
+            if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
+                version = match.group(1)
+                resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
+                metadata.thumbnail_url = resp["images"][0]["url"]
+                metadata.description = (
+                    f"Trigger terms: {(', ').join(resp['trainedWords'])}"
+                    if resp["trainedWords"]
+                    else resp["description"]
+                )
+                metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
+
+            # a Civitai model page
+            if match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
+                model = match.group(1)
+                resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
+
+                # note that we munge the URL here to get the download URL of the first model
+                url = resp["modelVersions"][0]["downloadUrl"]
+
+                metadata.author = resp["creator"]["username"]
+                metadata.tags = resp["tags"]
+                metadata.thumbnail_url = resp["modelVersions"][0]["images"][0]["url"]
+                metadata.license = f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
+        except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
+            self._logger.warn(excp)
+
+        # update metadata and return the download url
+        return url, metadata
+
     def _download_with_resume(self, job: DownloadJobBase):
         """Do the actual download."""
-        header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
-        open_mode = "wb"
-        exist_size = 0
-
-        resp = self._requests.get(job.source, headers=header, stream=True)
-        content_length = int(resp.headers.get("content-length", 0))
-        job.total_bytes = content_length
-
-        if job.destination.is_dir():
-            try:
-                file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
-                self._validate_filename(
-                    job.destination, file_name
-                )  # will raise a ValueError exception if file_name is suspicious
-            except ValueError:
-                self._logger.warning(
-                    f"Invalid filename '{file_name}' returned by source {job.source}, using last component of URL instead"
-                )
-                file_name = os.path.basename(job.source)
-            except KeyError:
-                file_name = os.path.basename(job.source)
-            job.destination = job.destination / file_name
-            dest = job.destination
-        else:
-            dest = job.destination
-        dest.parent.mkdir(parents=True, exist_ok=True)
-
         try:
+            url, metadata = self._fetch_metadata(job)
+            job.metadata = metadata
+
+            header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
+            open_mode = "wb"
+            exist_size = 0
+
+            resp = self._requests.get(url, headers=header, stream=True)
+            content_length = int(resp.headers.get("content-length", 0))
+            job.total_bytes = content_length
+
+            if job.destination.is_dir():
+                try:
+                    file_name = re.search('filename="(.+)"', resp.headers["Content-Disposition"]).group(1)
+                    self._validate_filename(
+                        job.destination, file_name
+                    )  # will raise a ValueError exception if file_name is suspicious
+                except ValueError:
+                    self._logger.warning(
+                        f"Invalid filename '{file_name}' returned by source {url}, using last component of URL instead"
+                    )
+                    file_name = os.path.basename(url)
+                except KeyError:
+                    file_name = os.path.basename(url)
+                job.destination = job.destination / file_name
+                dest = job.destination
+            else:
+                dest = job.destination
+            dest.parent.mkdir(parents=True, exist_ok=True)

             if dest.exists():
                 job.bytes = dest.stat().st_size
                 header["Range"] = f"bytes={job.bytes}-"
                 open_mode = "ab"
-                resp = self._requests.get(job.source, headers=header, stream=True)  # new request with range
+                resp = self._requests.get(url, headers=header, stream=True)  # new request with range

             if exist_size > content_length:
                 self._logger.warning("corrupt existing file found. re-downloading")
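Annotation: the resume branch at the end of this hunk re-issues the GET with an HTTP Range header once a partial file exists on disk; a 206 Partial Content response then means the server honored the range. The same idea in isolation, against a hypothetical URL:

    import requests

    url = "https://example.com/big-file.bin"  # hypothetical download
    existing_bytes = 1024                     # size of the partial file on disk

    resp = requests.get(url, headers={"Range": f"bytes={existing_bytes}-"}, stream=True)
    if resp.status_code == 206:  # Partial Content: server honored the range
        mode = "ab"              # append to the partial file
    else:                        # range ignored; restart from scratch
        mode = "wb"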
@@ -318,11 +378,11 @@ class DownloadQueue(DownloadQueueBase):
                 return

             if resp.status_code == 206 or exist_size > 0:
-                self._logger.warning(f"{dest}: partial file found. Resuming...")
+                self._logger.warning(f"{dest}: partial file found. Resuming")
             elif resp.status_code != 200:
                 raise HTTPError(resp.reason)
             else:
-                self._logger.info(f"{dest}: Downloading...")
+                self._logger.info(f"{job.source}: Downloading {job.destination}")

             report_delta = job.total_bytes / 100  # report every 1% change
             last_report_bytes = 0
@@ -338,8 +398,13 @@ class DownloadQueue(DownloadQueueBase):
                         self._update_job_status(job)

             self._update_job_status(job, DownloadJobStatus.COMPLETED)
+        except DuplicateModelException as excp:
+            self._logger.error(f"A model with the same hash as {dest} is already installed.")
+            job.error = excp
+            self._update_job_status(job, DownloadJobStatus.ERROR)
         except Exception as excp:
-            self._logger.error(f"An error occurred while downloading {dest}: {str(excp)}")
+            self._logger.error(f"An error occurred while downloading/installing {dest}: {str(excp)}")
+            traceback.print_exception(excp)
             job.error = excp
             self._update_job_status(job, DownloadJobStatus.ERROR)
@@ -363,8 +428,12 @@ class DownloadQueue(DownloadQueueBase):
         elif new_status in [DownloadJobStatus.COMPLETED, DownloadJobStatus.ERROR]:
             job.job_ended = time.time()
         if job.event_handlers:
-            for handler in job.event_handlers:
-                handler(job)
+            try:
+                for handler in job.event_handlers:
+                    handler(job)
+            except Exception as excp:
+                job.status = DownloadJobStatus.ERROR
+                job.error = excp

     def _download_repoid(self, job: DownloadJobBase):
         """Download a job that holds a huggingface repoid."""
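Annotation: guarding the handler loop matters because complete_installation (in the installer changes below) does real work, and any exception it raised previously would have unwound the worker thread. With the change, a failing callback simply marks the job errored. A self-contained demonstration of the pattern, with a stand-in Job class:

    class Job:
        status = "completed"
        error = None

    def bad_handler(job):
        raise RuntimeError("handler blew up")

    job = Job()
    try:
        for handler in [bad_handler]:
            handler(job)
    except Exception as excp:
        job.status = "error"  # record the failure on the job...
        job.error = excp      # ...instead of crashing the worker thread

    assert job.status == "error"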
@@ -410,8 +479,8 @@ class DownloadQueue(DownloadQueueBase):
                     access_token=job.access_token,
                 )
         except Exception as excp:
-            job.status = DownloadJobStatus.ERROR
             job.error = excp
+            self._update_job_status(job, DownloadJobStatus.ERROR)
             self._logger.error(job.error)
         finally:
             subqueue.join()
@@ -423,7 +492,7 @@ class DownloadQueue(DownloadQueueBase):
         self,
         repo_id: str,
         variant: Optional[str] = None,
-    ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], Dict[str, str]]:
+    ) -> Tuple[List[Tuple[AnyHttpUrl, Path, Path]], ModelSourceMetadata]:
         """Given a repo_id and an optional variant, return list of URLs to download to get the model."""
         model_info = HfApi().model_info(repo_id=repo_id, files_metadata=True)
         sibs = model_info.siblings
@@ -440,7 +509,12 @@ class DownloadQueue(DownloadQueueBase):
             (hf_hub_url(repo_id, filename=x.as_posix()), x.parent or Path("."), x.name, sizes[x.as_posix()])
             for x in self._select_variants(paths, variant)
         ]
-        return (urls, {"cardData": model_info.cardData, "tags": model_info.tags, "author": model_info.author})
+        return (
+            urls,
+            ModelSourceMetadata(
+                license=model_info.cardData.get("license"), tags=model_info.tags, author=model_info.author
+            ),
+        )

     def _select_variants(self, paths: List[str], variant: Optional[str] = None) -> Set[Path]:
         """Select the proper variant files from a list of HuggingFace repo_id paths."""
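Annotation: on the HuggingFace side, model_info.cardData is the parsed YAML front matter of the repo's README, which is where the license lives; tags and author come from the hub metadata itself. A sketch querying the hub directly (requires network access, and assumes a huggingface_hub version whose ModelInfo still exposes cardData as a dict, as the code above does):

    from huggingface_hub import HfApi

    info = HfApi().model_info(repo_id="stabilityai/stable-diffusion-2-1", files_metadata=True)
    print(info.cardData.get("license"))  # e.g. "openrail++", from the model card
    print(info.author, info.tags[:3])    # hub-level metadata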
@@ -57,7 +57,7 @@ from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util.logging import InvokeAILogger
 from .search import ModelSearch
 from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
-from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase
+from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata
 from .hash import FastModelHash
 from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
 from .config import (
@@ -250,9 +250,9 @@ class ModelInstall(ModelInstallBase):
         download: Optional[DownloadQueueBase] = None,
     ):  # noqa D107 - use base class docstrings
         self._config = config or InvokeAIAppConfig.get_config()
-        self._logger = logger or InvokeAILogger.getLogger()
+        self._logger = logger or InvokeAILogger.getLogger(config=self._config)
         self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
-        self._download_queue = download or DownloadQueue()
+        self._download_queue = download or DownloadQueue(config=self._config)
         self._async_installs = dict()
         self._tmpdir = None
@@ -274,6 +274,14 @@ class ModelInstall(ModelInstallBase):
         if info.model_type == ModelType.Main and info.format == ModelFormat.Checkpoint:
             try:
                 config_file = self._legacy_configs[info.base_type][info.variant_type]
+                if isinstance(config_file, dict):  # need another tier for sd-2.x models
+                    if prediction_type := info.prediction_type:
+                        config_file = config_file[prediction_type]
+                    else:
+                        self._logger.warning(
+                            f"Could not infer prediction type for {model_path.stem}. Guessing 'v_prediction' for a SD-2 768 pixel model"
+                        )
+                        config_file = config_file[SchedulerPredictionType.VPrediction]
             except KeyError as exc:
                 raise InvalidModelException("Configuration file for this checkpoint could not be determined") from exc
             registration_data.update(
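Annotation: the isinstance check implies _legacy_configs gained a third nesting level for SD-2 checkpoints, keyed by scheduler prediction type. A hypothetical fragment showing the shape being assumed; the real table, with its enum keys, is defined elsewhere in the installer and is not part of this diff:

    # Hypothetical shape of self._legacy_configs, string keys for illustration only.
    legacy_configs = {
        "sd-2": {
            "normal": {
                "epsilon": "v2-inference.yaml",
                "v_prediction": "v2-inference-v.yaml",
            },
        },
    }

    entry = legacy_configs["sd-2"]["normal"]
    if isinstance(entry, dict):  # the extra tier that the new code probes
        config_file = entry["v_prediction"]
        print(config_file)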
@@ -314,20 +322,18 @@ class ModelInstall(ModelInstallBase):
         self._async_installs[source] = None

         def complete_installation(job: DownloadJobBase):
             self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})")
             if job.status == "completed":
-                id = self.install(job.destination)
-                info = self._store.get_model(id)
-                info.description = f"Downloaded model {info.name}"
-                info.source_url = str(job.source)
-                if card_data := job.metadata.get("cardData"):
-                    info.license = card_data.get("license")
-                if author := job.metadata.get("author"):
-                    info.author = author
-                if tags := job.metadata.get("tags"):
-                    info.tags = tags
-                self._store.update_model(id, info)
-                self._async_installs[job.source] = id
+                model_id = self.install(job.destination)
+                info = self._store.get_model(model_id)
+                info.source = str(job.source)
+                metadata: ModelSourceMetadata = job.metadata
+                info.description = metadata.description or f"Downloaded model {info.name}"
+                info.author = metadata.author
+                info.tags = metadata.tags
+                info.license = metadata.license
+                info.thumbnail_url = metadata.thumbnail_url
+                self._store.update_model(model_id, info)
+                self._async_installs[job.source] = model_id
             jobs = queue.list_jobs()
             if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
                 self._tmpdir = None
@@ -336,9 +342,8 @@ class ModelInstall(ModelInstallBase):
         # will be deleted before the job actually runs.
         # Better to do the cleanup in the callback
         self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
-        return queue.create_download_job(
-            source=source, destdir=self._tmpdir.name, event_handlers=[complete_installation]
-        )
+        job = queue.create_download_job(source=source, destdir=self._tmpdir.name)
+        job.add_event_handler(complete_installation)

     def wait_for_downloads(self) -> Dict[str, str]:  # noqa D102
         self._download_queue.join()
@@ -206,7 +206,7 @@ class ModelConfigStoreSQL(ModelConfigStore):
         except sqlite3.IntegrityError as e:
             self._conn.rollback()
             if "UNIQUE constraint failed" in str(e):
-                raise DuplicateModelException from e
+                raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
             else:
                 raise e
         except sqlite3.Error as e:
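Annotation: the string match works because sqlite3 reports duplicate-key inserts as IntegrityError with a "UNIQUE constraint failed" message, as a quick in-memory check confirms:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE model (key TEXT PRIMARY KEY)")
    conn.execute("INSERT INTO model VALUES ('abc123')")
    try:
        conn.execute("INSERT INTO model VALUES ('abc123')")
    except sqlite3.IntegrityError as e:
        print(e)  # -> UNIQUE constraint failed: model.key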
@@ -116,7 +116,10 @@ class ModelConfigStoreYAML(ModelConfigStore):
         try:
             self._lock.acquire()
             if key in self._config:
-                raise DuplicateModelException(f"Duplicate key {key} for model named '{record.name}'")
+                existing_model = self.get_model(key)
+                raise DuplicateModelException(
+                    f"Can't save {record.name} because a model named '{existing_model.name}' is already stored with the same key '{key}'"
+                )
             self._config[key] = dict_fields
             self._commit()
         finally:
@@ -181,6 +181,7 @@ import urllib.parse

 from abc import abstractmethod
 from pathlib import Path
+from typing import Optional

 from invokeai.app.services.config import InvokeAIAppConfig
@@ -352,9 +353,8 @@ class InvokeAILogger(object):
     loggers = dict()

     @classmethod
-    def getLogger(
-        cls, name: str = "InvokeAI", config: InvokeAIAppConfig = InvokeAIAppConfig.get_config()
-    ) -> logging.Logger:
+    def getLogger(cls, name: str = "InvokeAI", config: Optional[InvokeAIAppConfig] = None) -> logging.Logger:
+        config = config or InvokeAIAppConfig.get_config()
         if name in cls.loggers:
             logger = cls.loggers[name]
             logger.handlers.clear()
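Annotation: this fixes a classic Python pitfall. A default argument is evaluated once, when the def statement runs, so the old signature froze whatever InvokeAIAppConfig.get_config() returned at import time; moving the call into the body re-evaluates it on every invocation. The bug class in miniature:

    import time

    def frozen(stamp=time.time()):  # default computed once, when `def` executes
        return stamp

    def fresh(stamp=None):  # sentinel default, resolved on every call
        return stamp if stamp is not None else time.time()

    a = frozen()
    time.sleep(0.01)
    b = frozen()
    assert a == b  # both calls saw the same stale value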
@@ -54,8 +54,14 @@ session.mount(
         },
     ),
 )
-# not found
+# not found test
 session.mount("http://www.civitai.com/models/broken", TestAdapter(b"Not found", status=404))

+# prevent us from going to civitai to get metadata
+session.mount("https://civitai.com/api/download/models/", TestAdapter(b"Not found", status=404))
+session.mount("https://civitai.com/api/v1/models/", TestAdapter(b"Not found", status=404))
+session.mount("https://civitai.com/api/v1/model-versions/", TestAdapter(b"Not found", status=404))
+
 # specifies a content disposition that may overwrite files in the parent directory
 session.mount(
     "http://www.civitai.com/models/malicious",
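Annotation: requests matches mounted adapters by longest URL prefix, so pinning 404 adapters on the three Civitai API prefixes keeps the test suite offline: any metadata probe from _fetch_metadata lands on the adapter rather than the real site. A minimal demonstration, assuming TestAdapter comes from the requests-testadapter package, as the usage in this test suggests:

    import requests
    from requests_testadapter import TestAdapter

    session = requests.Session()
    session.mount("https://civitai.com/api/v1/models/", TestAdapter(b"Not found", status=404))

    resp = session.get("https://civitai.com/api/v1/models/12345")
    print(resp.status_code)  # -> 404, served locally without any network traffic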