correctly download the selected version of a civitai model

This commit is contained in:
Lincoln Stein
2023-09-22 22:54:46 -04:00
parent d2cdbe5c4e
commit d5d517d2fa
2 changed files with 29 additions and 15 deletions

View File

@ -39,8 +39,9 @@ DOWNLOAD_CHUNK_SIZE = 100000
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
# endpoint for civitai get-model API
CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/"
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
@ -338,7 +339,7 @@ class DownloadQueue(DownloadQueueBase):
metadata_url = url
try:
# a Civitai download URL
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url):
version = match.group(1)
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
@ -347,15 +348,19 @@ class DownloadQueue(DownloadQueueBase):
if resp["trainedWords"]
else resp["description"]
)
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
# a Civitai model page
if match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
# a Civitai model page with the version
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url):
model = match.group(1)
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
version = int(match.group(2))
# and without
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
model = match.group(1)
version = None
# note that we munge the URL here to get the download URL of the first model
url = resp["modelVersions"][0]["downloadUrl"]
if model:
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
metadata.author = metadata.author or resp["creator"]["username"]
metadata.tags = metadata.tags or resp["tags"]
@ -363,19 +368,27 @@ class DownloadQueue(DownloadQueueBase):
metadata.license
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
)
first_version = resp["modelVersions"][0]
metadata.thumbnail_url = metadata.thumbnail_url or first_version.get("url")
if version:
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
version_data = versions[0]
else:
version_data = resp["modelVersions"][0] # first one
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
metadata.description = metadata.description or (
f"Trigger terms: {(', ').join(first_version.get('trainedWords'))}"
if first_version.get("trainedWords")
else first_version.get("description")
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
if version_data.get("trainedWords")
else version_data.get("description")
)
download_url = version_data["downloadUrl"]
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
self._logger.warn(excp)
# return the download url
return url
return download_url
def _download_with_resume(self, job: DownloadJobBase):
"""Do the actual download."""

View File

@ -501,7 +501,8 @@ class ModelInstall(ModelInstallBase):
else self._complete_installation_handler
)
job.probe_override = probe_override
job.metadata = metadata
if metadata:
job.metadata = metadata
job.add_event_handler(handler)
self._async_installs[source] = None