Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
move access token regex matching into download queue
parent 8e5e9b53d6
commit f211c95dbc
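This change moves the URL-regex-to-token matching out of ModelInstallService and into DownloadQueueService itself, so every queued download can pick up a token from the remote_api_tokens configuration; as a consequence, load_ckpt_from_url loses its access_token and timeout parameters. Below is a minimal, standalone sketch of the lookup the queue now performs. Only URLRegexTokenPair and the re.search logic come from the commit; the function name and example values are illustrative.

import re
from typing import Optional, Sequence

from invokeai.app.services.config.config_default import URLRegexTokenPair


def token_for_url(url: str, pairs: Sequence[URLRegexTokenPair]) -> Optional[str]:
    """Return the first configured token whose url_regex matches the URL, else None."""
    for pair in pairs:
        if re.search(pair.url_regex, url):
            return pair.token
    return None


# e.g. token_for_url("http://www.civitai.com/models/12345",
#                    [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]) == "cv_12345"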
@@ -93,7 +93,7 @@ class ApiDependencies:
         conditioning = ObjectSerializerForwardCache(
             ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
         )
-        download_queue_service = DownloadQueueService(event_bus=events)
+        download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
         model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
         model_manager = ModelManagerService.build_model_manager(
             app_config=configuration,
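The queue now receives the app config explicitly from ApiDependencies; when the argument is omitted it falls back to get_config(), as the DownloadQueueService.__init__ hunk further down shows. A small sketch of the two construction paths, assuming an InvokeAI development install (variable names are illustrative):

from invokeai.app.services.config import get_config
from invokeai.app.services.download import DownloadQueueService

explicit_queue = DownloadQueueService(app_config=get_config())  # what ApiDependencies now does with `configuration`
implicit_queue = DownloadQueueService()  # no app_config given: __init__ falls back to get_config()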
@@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl
 from requests import HTTPError
 from tqdm import tqdm
 
+from invokeai.app.services.config import InvokeAIAppConfig, get_config
 from invokeai.app.services.events.events_base import EventServiceBase
 from invokeai.app.util.misc import get_iso_timestamp
 from invokeai.backend.util.logging import InvokeAILogger
@@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def __init__(
         self,
         max_parallel_dl: int = 5,
+        app_config: Optional[InvokeAIAppConfig] = None,
         event_bus: Optional[EventServiceBase] = None,
         requests_session: Optional[requests.sessions.Session] = None,
     ):
         """
         Initialize DownloadQueue.
 
+        :param app_config: InvokeAIAppConfig object
         :param max_parallel_dl: Number of simultaneous downloads allowed [5].
         :param requests_session: Optional requests.sessions.Session object, for unit tests.
         """
+        self._app_config = app_config or get_config()
         self._jobs: Dict[int, DownloadJob] = {}
         self._next_job_id = 0
         self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
@@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
             source=source,
             dest=dest,
             priority=priority,
-            access_token=access_token,
+            access_token=access_token or self._lookup_access_token(source),
         )
         self.submit_download_job(
             job,
@@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
     def _in_progress_path(self, path: Path) -> Path:
         return path.with_name(path.name + ".downloading")
 
+    def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
+        # Pull the token from config if it exists and matches the URL
+        print(self._app_config)
+        token = None
+        for pair in self._app_config.remote_api_tokens or []:
+            if re.search(pair.url_regex, str(source)):
+                token = pair.token
+                break
+        return token
+
     def _signal_job_started(self, job: DownloadJob) -> None:
         job.status = DownloadJobStatus.RUNNING
         if job.on_start:
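Taken together, these hunks mean an explicitly supplied access_token still wins; the regex lookup only fills in when no token is passed. A sketch of the resulting precedence, assuming an InvokeAI development install (the destination, regex, and token values are placeholders; in the test suite a mocked requests session is injected instead, so no real network traffic occurs there):

from pathlib import Path

from pydantic.networks import AnyHttpUrl

from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadQueueService

config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]

queue = DownloadQueueService(app_config=config)
queue.start()

# No token supplied: the queue looks one up via remote_api_tokens -> "cv_12345"
job_a = queue.download(source=AnyHttpUrl("http://www.civitai.com/models/12345"), dest=Path("."))

# Explicit token supplied: it takes precedence and the config lookup is skipped
job_b = queue.download(
    source=AnyHttpUrl("http://www.civitai.com/models/12345"),
    dest=Path("."),
    access_token="my_personal_token",  # placeholder token
)

queue.stop()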
@@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase):
                 access_token=access_token,
             )
         elif re.match(r"^https?://[^/]+", source):
-            # Pull the token from config if it exists and matches the URL
-            _token = access_token
-            if _token is None:
-                for pair in self.app_config.remote_api_tokens or []:
-                    if re.search(pair.url_regex, source):
-                        _token = pair.token
-                        break
             source_obj = URLModelSource(
                 url=AnyHttpUrl(source),
-                access_token=_token,
+                access_token=access_token,
             )
         else:
             raise ValueError(f"Unsupported model source: '{source}'")
@@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
     def load_ckpt_from_url(
         self,
         source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: Optional[int] = 0,
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
@@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
         Args:
             source: A URL or a string that can be converted in one. Repo_ids
                 do not work here.
-            access_token: Optional access token for restricted resources.
-            timeout: Wait up to the indicated number of seconds before timing
-                out long downloads.
             loader: A Callable that expects a Path and returns a Dict[str|int, Any]
 
         Returns:
@@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
     def load_ckpt_from_url(
         self,
         source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: Optional[int] = 0,
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
@@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
         Args:
             source: A URL or a string that can be converted in one. Repo_ids
                 do not work here.
-            access_token: Optional access token for restricted resources.
-            timeout: Wait up to the indicated number of seconds before timing
-                out long downloads.
             loader: A Callable that expects a Path and returns a Dict[str|int, Any]
 
         Returns:
             A LoadedModel object.
         """
-        model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
+        model_path = self.install.download_and_cache_ckpt(source=source)
         return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
@@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface):
     def load_ckpt_from_url(
         self,
         source: str | AnyHttpUrl,
-        access_token: Optional[str] = None,
-        timeout: Optional[int] = 0,
         loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
     ) -> LoadedModel:
         """
@@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface):
         Args:
             source: A URL or a string that can be converted in one. Repo_ids
                 do not work here.
-            access_token: Optional access token for restricted resources.
-            timeout: Wait up to the indicated number of seconds before timing
-                out long downloads.
             loader: A Callable that expects a Path and returns a Dict[str|int, Any]
 
         Returns:
             A LoadedModel object.
         """
-        result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
-            source=source, access_token=access_token, timeout=timeout, loader=loader
-        )
+        result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
         return result
 
 
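From a node's point of view the call shrinks to source plus an optional loader; credentials for gated hosts now come from the remote_api_tokens setting rather than from an access_token argument. A sketch of the slimmed-down call, where the URL and helper function are illustrative and context is assumed to be the InvocationContext handed to a node's invoke() method:

from invokeai.app.services.shared.invocation_context import InvocationContext


def fetch_checkpoint(context: InvocationContext):
    # access_token and timeout are no longer accepted here; the download queue
    # resolves tokens for matching URLs from remote_api_tokens on its own.
    return context.models.load_ckpt_from_url(
        source="https://huggingface.co/example/repo/resolve/main/model.safetensors",  # placeholder URL
    )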
@@ -2,14 +2,19 @@
 
 import re
 import time
+from contextlib import contextmanager
 from pathlib import Path
+from typing import Generator
 
 import pytest
 from pydantic.networks import AnyHttpUrl
 from requests.sessions import Session
 from requests_testadapter import TestAdapter, TestSession
 
+from invokeai.app.services.config import get_config
+from invokeai.app.services.config.config_default import URLRegexTokenPair
 from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
+from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403
 from tests.test_nodes import TestEventService
 
 # Prevent pytest deprecation warnings
@@ -34,6 +39,17 @@ def session() -> Session:
             ),
         )
 
+    sess.mount(
+        "http://www.huggingface.co/foo.txt",
+        TestAdapter(
+            content,
+            headers={
+                "Content-Length": len(content),
+                "Content-Disposition": 'filename="foo.safetensors"',
+            },
+        ),
+    )
+
     # here are some malformed URLs to test
     # missing the content length
     sess.mount(
@@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
     assert events[-1].event_name == "download_cancelled"
     assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
     queue.stop()
+
+
+@contextmanager
+def clear_config() -> Generator[None, None, None]:
+    try:
+        yield None
+    finally:
+        get_config.cache_clear()
+
+
+def test_tokens(tmp_path: Path, session: Session):
+    with clear_config():
+        config = get_config()
+        config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
+        queue = DownloadQueueService(requests_session=session)
+        queue.start()
+        # this one has an access token assigned
+        job1 = queue.download(
+            source=AnyHttpUrl("http://www.civitai.com/models/12345"),
+            dest=tmp_path,
+        )
+        # this one doesn't
+        job2 = queue.download(
+            source=AnyHttpUrl(
+                "http://www.huggingface.co/foo.txt",
+            ),
+            dest=tmp_path,
+        )
+        queue.join()
+        # this token is defined in the temporary root invokeai.yaml
+        # see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
+        assert job1.access_token == "cv_12345"
+        assert job2.access_token is None
+        queue.stop()
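One note on the clear_config helper added above: get_config() caches the loaded InvokeAIAppConfig (its cache_clear() method is what the finally block calls), so a test that mutates remote_api_tokens has to reset that cache or the change would leak into later tests. A minimal sketch of the behaviour being guarded against, assuming get_config is the functools-cached accessor its cache_clear() implies:

from invokeai.app.services.config import get_config

first = get_config()
assert get_config() is first       # repeated calls return the same cached config object

first.remote_api_tokens = []       # so any mutation is visible process-wide

get_config.cache_clear()           # what clear_config() does on exit
assert get_config() is not first   # the next caller gets a freshly built config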