mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
move access token regex matching into download queue
This commit is contained in:
@ -2,14 +2,19 @@
|
||||
|
||||
import re
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Generator
|
||||
|
||||
import pytest
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from requests.sessions import Session
|
||||
from requests_testadapter import TestAdapter, TestSession
|
||||
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
from tests.test_nodes import TestEventService
|
||||
|
||||
# Prevent pytest deprecation warnings
|
||||
@ -34,6 +39,17 @@ def session() -> Session:
|
||||
),
|
||||
)
|
||||
|
||||
sess.mount(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
TestAdapter(
|
||||
content,
|
||||
headers={
|
||||
"Content-Length": len(content),
|
||||
"Content-Disposition": 'filename="foo.safetensors"',
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# here are some malformed URLs to test
|
||||
# missing the content length
|
||||
sess.mount(
|
||||
@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
||||
assert events[-1].event_name == "download_cancelled"
|
||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||
queue.stop()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def clear_config() -> Generator[None, None, None]:
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
get_config.cache_clear()
|
||||
|
||||
|
||||
def test_tokens(tmp_path: Path, session: Session):
|
||||
with clear_config():
|
||||
config = get_config()
|
||||
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||
queue = DownloadQueueService(requests_session=session)
|
||||
queue.start()
|
||||
# this one has an access token assigned
|
||||
job1 = queue.download(
|
||||
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||
dest=tmp_path,
|
||||
)
|
||||
# this one doesn't
|
||||
job2 = queue.download(
|
||||
source=AnyHttpUrl(
|
||||
"http://www.huggingface.co/foo.txt",
|
||||
),
|
||||
dest=tmp_path,
|
||||
)
|
||||
queue.join()
|
||||
# this token is defined in the temporary root invokeai.yaml
|
||||
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
|
||||
assert job1.access_token == "cv_12345"
|
||||
assert job2.access_token is None
|
||||
queue.stop()
|
||||
|
Reference in New Issue
Block a user