move access token regex matching into download queue

This commit is contained in:
Lincoln Stein 2024-05-05 21:00:31 -04:00
parent 8e5e9b53d6
commit f211c95dbc
7 changed files with 69 additions and 29 deletions

View File

@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,

View File

@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
access_token=access_token,
access_token=access_token or self._lookup_access_token(source),
)
self.submit_download_job(
job,
@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
    """Return `path` with ".downloading" appended to its final component."""
    in_progress_name = f"{path.name}.downloading"
    return path.with_name(in_progress_name)
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
    """
    Return the configured access token whose URL pattern matches `source`.

    Walks `remote_api_tokens` on the service's app config in order and returns the
    token of the first pair whose `url_regex` is found anywhere in the stringified
    URL (`re.search`, so the pattern is not anchored to the start of the URL).

    :param source: URL of the resource about to be downloaded.
    :return: The first matching token, or None when no pairs are configured or none match.
    """
    # NOTE: do not print or log the config object here; it carries the tokens themselves.
    url = str(source)
    for pair in self._app_config.remote_api_tokens or []:
        if re.search(pair.url_regex, url):
            return pair.token
    return None
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:

View File

@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")

View File

@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:

View File

@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
model_path = self.install.download_and_cache_ckpt(source=source)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)

View File

@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface):
Args:
source: A URL or a string that can be converted into one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
source=source, access_token=access_token, timeout=timeout, loader=loader
)
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
return result

View File

@ -2,14 +2,19 @@
import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
@ -34,6 +39,17 @@ def session() -> Session:
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@contextmanager
def clear_config() -> Generator[None, None, None]:
    """Run the caller's `with` block, then clear the `get_config` cache, even if the block raised."""
    try:
        yield
    finally:
        # Presumably so that config changes made inside the block do not leak
        # into later get_config() callers -- confirm against get_config's caching.
        get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session) -> None:
    """Check that the download queue fills in access tokens from `remote_api_tokens` by URL regex."""
    with clear_config():
        config = get_config()
        config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
        # Hand the modified config to the service explicitly instead of relying on
        # the service falling back to get_config() on its own.
        queue = DownloadQueueService(app_config=config, requests_session=session)
        queue.start()
        try:
            # this URL matches the "civitai" pattern, so a token is assigned
            job1 = queue.download(
                source=AnyHttpUrl("http://www.civitai.com/models/12345"),
                dest=tmp_path,
            )
            # this URL matches no configured pattern, so no token is assigned
            job2 = queue.download(
                source=AnyHttpUrl(
                    "http://www.huggingface.co/foo.txt",
                ),
                dest=tmp_path,
            )
            queue.join()
            # the expected token is the one installed on the config above
            assert job1.access_token == "cv_12345"
            assert job2.access_token is None
        finally:
            # stop the queue even when a download or an assertion fails
            queue.stop()