mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add remote_repo_api_key config to be added as a token query param for all remote url model downloads
This commit is contained in:
parent
952d97741e
commit
73a190fb6e
@ -287,6 +287,9 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||
|
||||
# MODEL IMPORT
|
||||
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
@ -241,12 +241,15 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
"""Do the actual download."""
|
||||
url = job.source
|
||||
query_params = url.query_params()
|
||||
if job.access_token:
|
||||
query_params.append(("access_token", job.access_token))
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
|
||||
# Make a streaming request. This will retrieve headers including
|
||||
# content-length and content-disposition, but not fetch any content itself
|
||||
resp = self._requests.get(str(url), headers=header, stream=True)
|
||||
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
|
||||
if not resp.ok:
|
||||
raise HTTPError(resp.reason)
|
||||
|
||||
|
@ -199,7 +199,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=access_token,
|
||||
access_token=self.app_config.remote_repo_api_key,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
|
Loading…
Reference in New Issue
Block a user