Add remote_repo_api_key config to be added as a token query param for all remote url model downloads

This commit is contained in:
Brandon Rising 2024-03-07 11:39:20 -05:00 committed by psychedelicious
parent 952d97741e
commit 73a190fb6e
3 changed files with 8 additions and 2 deletions

View File

@ -287,6 +287,9 @@ class InvokeAIAppConfig(InvokeAISettings):
deny_nodes : Optional[List[str]] = Field(default=None, description="List of nodes to deny. Omit to deny none.", json_schema_extra=Categories.Nodes)
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
# MODEL IMPORT
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
max_cache_size : Optional[float] = Field(default=None, gt=0, description="Maximum memory amount used by model cache for rapid switching", json_schema_extra=Categories.MemoryPerformance)

View File

@ -241,12 +241,15 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
query_params = url.query_params()
if job.access_token:
query_params.append(("access_token", job.access_token))
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), headers=header, stream=True)
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)

View File

@ -199,7 +199,7 @@ class ModelInstallService(ModelInstallServiceBase):
elif re.match(r"^https?://[^/]+", source):
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=access_token,
access_token=self.app_config.remote_repo_api_key,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")