mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(mm): support generic API tokens via regex/token pairs in config
A list of regex and token pairs is accepted. As a file is downloaded by the model installer, the URL is tested against the provided regex/token pairs. The token for the first matching regex is used during download, added as a bearer token.
This commit is contained in:
parent
b6065d6328
commit
576bb4a61d
@ -170,11 +170,12 @@ two configs are kept in separate sections of the config file:
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, ClassVar, Dict, List, Literal, Optional
|
||||
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic.config import JsonDict
|
||||
from pydantic_settings import SettingsConfigDict
|
||||
|
||||
@ -205,6 +206,21 @@ class Categories(object):
|
||||
MemoryPerformance: JsonDict = {"category": "Memory/Performance"}
|
||||
|
||||
|
||||
class URLRegexToken(BaseModel):
|
||||
url_regex: str = Field(description="Regular expression to match against the URL")
|
||||
token: str = Field(description="Token to use when the URL matches the regex")
|
||||
|
||||
@field_validator("url_regex")
|
||||
@classmethod
|
||||
def validate_url_regex(cls, v: str) -> str:
|
||||
"""Validate that the value is a valid regex."""
|
||||
try:
|
||||
re.compile(v)
|
||||
except re.error as e:
|
||||
raise ValueError(f"Invalid regex: {e}")
|
||||
return v
|
||||
|
||||
|
||||
class InvokeAIAppConfig(InvokeAISettings):
|
||||
"""Configuration object for InvokeAI App."""
|
||||
|
||||
@ -288,7 +304,7 @@ class InvokeAIAppConfig(InvokeAISettings):
|
||||
node_cache_size : int = Field(default=512, description="How many cached nodes to keep in memory", json_schema_extra=Categories.Nodes)
|
||||
|
||||
# MODEL IMPORT
|
||||
remote_repo_api_key : Optional[str] = Field(default=os.environ.get("INVOKEAI_REMOTE_REPO_API_KEY"), description="API key used when downloading remote repositories", json_schema_extra=Categories.Other)
|
||||
remote_api_tokens : Optional[list[URLRegexToken]] = Field(default=None, description="List of regular expression and token pairs used when downloading models from URLs. The download URL is tested against the regex, and if it matches, the token is provided in as a Bearer token.", json_schema_extra=Categories.Other)
|
||||
|
||||
# DEPRECATED FIELDS - STILL HERE IN ORDER TO OBTAN VALUES FROM PRE-3.1 CONFIG FILES
|
||||
always_use_cpu : bool = Field(default=False, description="If true, use the CPU for rendering even if a GPU is available.", json_schema_extra=Categories.MemoryPerformance)
|
||||
|
@ -241,15 +241,12 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
def _do_download(self, job: DownloadJob) -> None:
|
||||
"""Do the actual download."""
|
||||
url = job.source
|
||||
query_params = url.query_params()
|
||||
if job.access_token:
|
||||
query_params.append(("access_token", job.access_token))
|
||||
header = {"Authorization": f"Bearer {job.access_token}"} if job.access_token else {}
|
||||
open_mode = "wb"
|
||||
|
||||
# Make a streaming request. This will retrieve headers including
|
||||
# content-length and content-disposition, but not fetch any content itself
|
||||
resp = self._requests.get(str(url), params=query_params, headers=header, stream=True)
|
||||
resp = self._requests.get(str(url), headers=header, stream=True)
|
||||
if not resp.ok:
|
||||
raise HTTPError(resp.reason)
|
||||
|
||||
|
@ -197,9 +197,16 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
access_token=access_token,
|
||||
)
|
||||
elif re.match(r"^https?://[^/]+", source):
|
||||
# Pull the token from config if it exists and matches the URL
|
||||
_token = access_token
|
||||
if _token is None:
|
||||
for pair in self.app_config.remote_api_tokens or []:
|
||||
if re.search(pair.url_regex, source):
|
||||
_token = pair.token
|
||||
break
|
||||
source_obj = URLModelSource(
|
||||
url=AnyHttpUrl(source),
|
||||
access_token=self.app_config.remote_repo_api_key,
|
||||
access_token=_token,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported model source: '{source}'")
|
||||
|
Loading…
Reference in New Issue
Block a user