feat(mm): support generic API tokens via regex/token pairs in config

A list of regex and token pairs is accepted. As a file is downloaded by the model installer, the URL is tested against the provided regex/token pairs. The token for the first matching regex is used during download, added as a bearer token.
This commit is contained in:
psychedelicious 2024-03-08 13:32:26 +11:00
parent b6065d6328
commit 576bb4a61d
3 changed files with 27 additions and 7 deletions

View File

@ -170,11 +170,12 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional
from omegaconf import DictConfig, OmegaConf
from pydantic import Field
from pydantic import BaseModel, Field, field_validator
from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict
@ -205,6 +206,21 @@ class Categories(object):
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
class URLRegexToken(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")
@field_validator("url_regex")
@classmethod
def validate_url_regex(cls, v: str) -> str:
"""Validate that the value is a valid regex."""
try:
re.compile(v)
except re.error as e:
raise ValueError(f"Invalid regex: {e}")
return v
class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App."""
@ -288,7 +304,7 @@ class InvokeAIAppConfig(InvokeAISettings):
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
# MODEL IMPORT
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
remote_api_tokens : Optional[list[URLRegexToken]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", json_schema_extra=Categories.Other)
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)

View File

@ -241,15 +241,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download."""
url = job.source
query_params = url.query_params()
if job.access_token:
query_params.append(("access_token", job.access_token))
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb"
# Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok:
raise HTTPError(resp.reason)

View File

@ -197,9 +197,16 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=self.app_config.remote_repo_api_key,
access_token=_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")