feat(mm): support generic API tokens via regex/token pairs in config

A list of regex and token pairs is accepted. As a file is downloaded by the model installer, the URL is tested against the provided regex/token pairs. The token for the first matching regex is used during download, added as a bearer token.
This commit is contained in:
psychedelicious 2024-03-08 13:32:26 +11:00
parent b6065d6328
commit 576bb4a61d
3 changed files with 27 additions and 7 deletions

View File

@ -170,11 +170,12 @@ two configs are kept in separate sections of the config file:
from __future__ import annotations from __future__ import annotations
import os import os
import re
from pathlib import Path from pathlib import Path
from typing import Any, ClassVar, Dict, List, Literal, Optional from typing import Any, ClassVar, Dict, List, Literal, Optional
from omegaconf import DictConfig, OmegaConf from omegaconf import DictConfig, OmegaConf
from pydantic import Field from pydantic import BaseModel, Field, field_validator
from pydantic.config import JsonDict from pydantic.config import JsonDict
from pydantic_settings import SettingsConfigDict from pydantic_settings import SettingsConfigDict
@ -205,6 +206,21 @@ class Categories(object):
MemoryPerformance: JsonDict = {"category": "Memory/Performance"} MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
class URLRegexToken(BaseModel):
url_regex: str = Field(description="Regular expression to match against the URL")
token: str = Field(description="Token to use when the URL matches the regex")
@field_validator("url_regex")
@classmethod
def validate_url_regex(cls, v: str) -> str:
"""Validate that the value is a valid regex."""
try:
re.compile(v)
except re.error as e:
raise ValueError(f"Invalid regex: {e}")
return v
class InvokeAIAppConfig(InvokeAISettings): class InvokeAIAppConfig(InvokeAISettings):
"""Configuration object for InvokeAI App.""" """Configuration object for InvokeAI App."""
@ -288,7 +304,7 @@ class InvokeAIAppConfig(InvokeAISettings):
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes) node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
# MODEL IMPORT # MODEL IMPORT
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other) remote_api_tokens : Optional[list[URLRegexToken]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", json_schema_extra=Categories.Other)
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES # DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance) always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)

View File

@ -241,15 +241,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _do_download(self, job: DownloadJob) -> None: def _do_download(self, job: DownloadJob) -> None:
"""Do the actual download.""" """Do the actual download."""
url = job.source url = job.source
query_params = url.query_params()
if job.access_token:
query_params.append(("access_token", job.access_token))
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {} header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
open_mode = "wb" open_mode = "wb"
# Make a streaming request. This will retrieve headers including # Make a streaming request. This will retrieve headers including
# content-length and content-disposition, but not fetch any content itself # content-length and content-disposition, but not fetch any content itself
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True) resp = self._requests.get(str(url), headers=header, stream=True)
if not resp.ok: if not resp.ok:
raise HTTPError(resp.reason) raise HTTPError(resp.reason)

View File

@ -197,9 +197,16 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token, access_token=access_token,
) )
elif re.match(r"^https?://[^/]+", source): elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource( source_obj = URLModelSource(
url=AnyHttpUrl(source), url=AnyHttpUrl(source),
access_token=self.app_config.remote_repo_api_key, access_token=_token,
) )
else: else:
raise ValueError(f"Unsupported model source: '{source}'") raise ValueError(f"Unsupported model source: '{source}'")