move access token regex matching into download queue

This commit is contained in:
Lincoln Stein 2024-05-05 21:00:31 -04:00
parent 8e5e9b53d6
commit f211c95dbc
7 changed files with 69 additions and 29 deletions

View File

@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache( conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True) ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
) )
download_queue_service = DownloadQueueService(event_bus=events) download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images") model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager( model_manager = ModelManagerService.build_model_manager(
app_config=configuration, app_config=configuration,

View File

@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl
from requests import HTTPError from requests import HTTPError
from tqdm import tqdm from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__( def __init__(
self, self,
max_parallel_dl: int = 5, max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional[EventServiceBase] = None, event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None, requests_session: Optional[requests.sessions.Session] = None,
): ):
""" """
Initialize DownloadQueue. Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5]. :param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests. :param requests_session: Optional requests.sessions.Session object, for unit tests.
""" """
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {} self._jobs: Dict[int, DownloadJob] = {}
self._next_job_id = 0 self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue() self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source, source=source,
dest=dest, dest=dest,
priority=priority, priority=priority,
access_token=access_token, access_token=access_token or self._lookup_access_token(source),
) )
self.submit_download_job( self.submit_download_job(
job, job,
@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path: def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading") return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
print(self._app_config)
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
def _signal_job_started(self, job: DownloadJob) -> None: def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING job.status = DownloadJobStatus.RUNNING
if job.on_start: if job.on_start:

View File

@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token, access_token=access_token,
) )
elif re.match(r"^https?://[^/]+", source): elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource( source_obj = URLModelSource(
url=AnyHttpUrl(source), url=AnyHttpUrl(source),
access_token=_token, access_token=access_token,
) )
else: else:
raise ValueError(f"Unsupported model source: '{source}'") raise ValueError(f"Unsupported model source: '{source}'")

View File

@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
def load_ckpt_from_url( def load_ckpt_from_url(
self, self,
source: str | AnyHttpUrl, source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
Args: Args:
source: A URL or a string that can be converted in one. Repo_ids source: A URL or a string that can be converted in one. Repo_ids
do not work here. do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any] loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns: Returns:

View File

@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
def load_ckpt_from_url( def load_ckpt_from_url(
self, self,
source: str | AnyHttpUrl, source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
Args: Args:
source: A URL or a string that can be converted in one. Repo_ids source: A URL or a string that can be converted in one. Repo_ids
do not work here. do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any] loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns: Returns:
A LoadedModel object. A LoadedModel object.
""" """
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout) model_path = self.install.download_and_cache_ckpt(source=source)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader) return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)

View File

@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface):
def load_ckpt_from_url( def load_ckpt_from_url(
self, self,
source: str | AnyHttpUrl, source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None, loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel: ) -> LoadedModel:
""" """
@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface):
Args: Args:
source: A URL or a string that can be converted in one. Repo_ids source: A URL or a string that can be converted in one. Repo_ids
do not work here. do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any] loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns: Returns:
A LoadedModel object. A LoadedModel object.
""" """
result: LoadedModel = self._services.model_manager.load_ckpt_from_url( result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
source=source, access_token=access_token, timeout=timeout, loader=loader
)
return result return result

View File

@ -2,14 +2,19 @@
import re import re
import time import time
from contextlib import contextmanager
from pathlib import Path from pathlib import Path
from typing import Generator
import pytest import pytest
from pydantic.networks import AnyHttpUrl from pydantic.networks import AnyHttpUrl
from requests.sessions import Session from requests.sessions import Session
from requests_testadapter import TestAdapter, TestSession from requests_testadapter import TestAdapter, TestSession
from invokeai.app.services.config import get_config
from invokeai.app.services.config.config_default import URLRegexTokenPair
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
from tests.test_nodes import TestEventService from tests.test_nodes import TestEventService
# Prevent pytest deprecation warnings # Prevent pytest deprecation warnings
@ -34,6 +39,17 @@ def session() -> Session:
), ),
) )
sess.mount(
"http://www.huggingface.co/foo.txt",
TestAdapter(
content,
headers={
"Content-Length": len(content),
"Content-Disposition": 'filename="foo.safetensors"',
},
),
)
# here are some malformed URLs to test # here are some malformed URLs to test
# missing the content length # missing the content length
sess.mount( sess.mount(
@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
assert events[-1].event_name == "download_cancelled" assert events[-1].event_name == "download_cancelled"
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345" assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
queue.stop() queue.stop()
@contextmanager
def clear_config() -> Generator[None, None, None]:
try:
yield None
finally:
get_config.cache_clear()
def test_tokens(tmp_path: Path, session: Session):
with clear_config():
config = get_config()
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
queue = DownloadQueueService(requests_session=session)
queue.start()
# this one has an access token assigned
job1 = queue.download(
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
dest=tmp_path,
)
# this one doesn't
job2 = queue.download(
source=AnyHttpUrl(
"http://www.huggingface.co/foo.txt",
),
dest=tmp_path,
)
queue.join()
# this token is defined in the temporary root invokeai.yaml
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
assert job1.access_token == "cv_12345"
assert job2.access_token is None
queue.stop()