move access token regex matching into download queue

This commit is contained in:
Lincoln Stein
2024-05-05 21:00:31 -04:00
parent 8e5e9b53d6
commit f211c95dbc
7 changed files with 69 additions and 29 deletions

View File

@ -2,14 +2,19 @@
import re
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Generator
import pytest
from pydantic.networks import AnyHttpUrl
from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings
@ -34,6 +39,17 @@ def session() -> Session:
),
)
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test
# missing the content length
sess.mount(
@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop()
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=session)
queue.start()
# this one has an access token assigned
job1 = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
# this one doesn't
job2 = queue.download(
source=AnyHttpUrl(
"http://www.huggingface.co/foo.txt",
),
dest=tmp_path,
)
queue.join()
# this token is defined in the temporary root invokeai.yaml
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
assert job1.access_token == "cv_12345"
assert job2.access_token is None
queue.stop()