mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
correctly download the selected version of a civitai model
This commit is contained in:
@ -39,8 +39,9 @@ DOWNLOAD_CHUNK_SIZE = 100000
|
||||
STOP_JOB = DownloadJobBase(id=-99, priority=-99, source="dummy", destination="/")
|
||||
|
||||
# endpoint for civitai get-model API
|
||||
CIVITAI_MODEL_DOWNLOAD = "https://civitai.com/api/download/models/"
|
||||
CIVITAI_MODEL_DOWNLOAD = r"https://civitai.com/api/download/models/(\d+)"
|
||||
CIVITAI_MODEL_PAGE = "https://civitai.com/models/"
|
||||
CIVITAI_MODEL_PAGE_WITH_VERSION = r"https://civitai.com/models/(\d+)\?modelVersionId=(\d+)"
|
||||
CIVITAI_MODELS_ENDPOINT = "https://civitai.com/api/v1/models/"
|
||||
CIVITAI_VERSIONS_ENDPOINT = "https://civitai.com/api/v1/model-versions/"
|
||||
|
||||
@ -338,7 +339,7 @@ class DownloadQueue(DownloadQueueBase):
|
||||
metadata_url = url
|
||||
try:
|
||||
# a Civitai download URL
|
||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD + r"(\d+)", metadata_url):
|
||||
if match := re.match(CIVITAI_MODEL_DOWNLOAD, metadata_url):
|
||||
version = match.group(1)
|
||||
resp = self._requests.get(CIVITAI_VERSIONS_ENDPOINT + version).json()
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or resp["images"][0]["url"]
|
||||
@ -347,15 +348,19 @@ class DownloadQueue(DownloadQueueBase):
|
||||
if resp["trainedWords"]
|
||||
else resp["description"]
|
||||
)
|
||||
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"])
|
||||
metadata_url = CIVITAI_MODEL_PAGE + str(resp["modelId"]) + f"?modelVersionId={version}"
|
||||
|
||||
# a Civitai model page
|
||||
if match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
|
||||
# a Civitai model page with the version
|
||||
if match := re.match(CIVITAI_MODEL_PAGE_WITH_VERSION, metadata_url):
|
||||
model = match.group(1)
|
||||
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
|
||||
version = int(match.group(2))
|
||||
# and without
|
||||
elif match := re.match(CIVITAI_MODEL_PAGE + r"(\d+)", metadata_url):
|
||||
model = match.group(1)
|
||||
version = None
|
||||
|
||||
# note that we munge the URL here to get the download URL of the first model
|
||||
url = resp["modelVersions"][0]["downloadUrl"]
|
||||
if model:
|
||||
resp = self._requests.get(CIVITAI_MODELS_ENDPOINT + str(model)).json()
|
||||
|
||||
metadata.author = metadata.author or resp["creator"]["username"]
|
||||
metadata.tags = metadata.tags or resp["tags"]
|
||||
@ -363,19 +368,27 @@ class DownloadQueue(DownloadQueueBase):
|
||||
metadata.license
|
||||
or f"allowCommercialUse={resp['allowCommercialUse']}; allowDerivatives={resp['allowDerivatives']}; allowNoCredit={resp['allowNoCredit']}"
|
||||
)
|
||||
first_version = resp["modelVersions"][0]
|
||||
metadata.thumbnail_url = metadata.thumbnail_url or first_version.get("url")
|
||||
|
||||
if version:
|
||||
versions = [x for x in resp["modelVersions"] if int(x["id"]) == version]
|
||||
version_data = versions[0]
|
||||
else:
|
||||
version_data = resp["modelVersions"][0] # first one
|
||||
|
||||
metadata.thumbnail_url = version_data.get("url") or metadata.thumbnail_url
|
||||
metadata.description = metadata.description or (
|
||||
f"Trigger terms: {(', ').join(first_version.get('trainedWords'))}"
|
||||
if first_version.get("trainedWords")
|
||||
else first_version.get("description")
|
||||
f"Trigger terms: {(', ').join(version_data.get('trainedWords'))}"
|
||||
if version_data.get("trainedWords")
|
||||
else version_data.get("description")
|
||||
)
|
||||
|
||||
download_url = version_data["downloadUrl"]
|
||||
|
||||
except (HTTPError, KeyError, TypeError, JSONDecodeError) as excp:
|
||||
self._logger.warn(excp)
|
||||
|
||||
# return the download url
|
||||
return url
|
||||
return download_url
|
||||
|
||||
def _download_with_resume(self, job: DownloadJobBase):
|
||||
"""Do the actual download."""
|
||||
|
@ -501,7 +501,8 @@ class ModelInstall(ModelInstallBase):
|
||||
else self._complete_installation_handler
|
||||
)
|
||||
job.probe_override = probe_override
|
||||
job.metadata = metadata
|
||||
if metadata:
|
||||
job.metadata = metadata
|
||||
job.add_event_handler(handler)
|
||||
|
||||
self._async_installs[source] = None
|
||||
|
Reference in New Issue
Block a user