mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
move access token regex matching into download queue
This commit is contained in:
parent
8e5e9b53d6
commit
f211c95dbc
@ -93,7 +93,7 @@ class ApiDependencies:
|
|||||||
conditioning = ObjectSerializerForwardCache(
|
conditioning = ObjectSerializerForwardCache(
|
||||||
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
|
||||||
)
|
)
|
||||||
download_queue_service = DownloadQueueService(event_bus=events)
|
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
|
||||||
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
|
||||||
model_manager = ModelManagerService.build_model_manager(
|
model_manager = ModelManagerService.build_model_manager(
|
||||||
app_config=configuration,
|
app_config=configuration,
|
||||||
|
@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl
|
|||||||
from requests import HTTPError
|
from requests import HTTPError
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.util.misc import get_iso_timestamp
|
from invokeai.app.util.misc import get_iso_timestamp
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_parallel_dl: int = 5,
|
max_parallel_dl: int = 5,
|
||||||
|
app_config: Optional[InvokeAIAppConfig] = None,
|
||||||
event_bus: Optional[EventServiceBase] = None,
|
event_bus: Optional[EventServiceBase] = None,
|
||||||
requests_session: Optional[requests.sessions.Session] = None,
|
requests_session: Optional[requests.sessions.Session] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initialize DownloadQueue.
|
Initialize DownloadQueue.
|
||||||
|
|
||||||
|
:param app_config: InvokeAIAppConfig object
|
||||||
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
|
||||||
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
:param requests_session: Optional requests.sessions.Session object, for unit tests.
|
||||||
"""
|
"""
|
||||||
|
self._app_config = app_config or get_config()
|
||||||
self._jobs: Dict[int, DownloadJob] = {}
|
self._jobs: Dict[int, DownloadJob] = {}
|
||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
|
||||||
@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
source=source,
|
source=source,
|
||||||
dest=dest,
|
dest=dest,
|
||||||
priority=priority,
|
priority=priority,
|
||||||
access_token=access_token,
|
access_token=access_token or self._lookup_access_token(source),
|
||||||
)
|
)
|
||||||
self.submit_download_job(
|
self.submit_download_job(
|
||||||
job,
|
job,
|
||||||
@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
def _in_progress_path(self, path: Path) -> Path:
|
def _in_progress_path(self, path: Path) -> Path:
|
||||||
return path.with_name(path.name + ".downloading")
|
return path.with_name(path.name + ".downloading")
|
||||||
|
|
||||||
|
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
|
||||||
|
# Pull the token from config if it exists and matches the URL
|
||||||
|
print(self._app_config)
|
||||||
|
token = None
|
||||||
|
for pair in self._app_config.remote_api_tokens or []:
|
||||||
|
if re.search(pair.url_regex, str(source)):
|
||||||
|
token = pair.token
|
||||||
|
break
|
||||||
|
return token
|
||||||
|
|
||||||
def _signal_job_started(self, job: DownloadJob) -> None:
|
def _signal_job_started(self, job: DownloadJob) -> None:
|
||||||
job.status = DownloadJobStatus.RUNNING
|
job.status = DownloadJobStatus.RUNNING
|
||||||
if job.on_start:
|
if job.on_start:
|
||||||
|
@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
elif re.match(r"^https?://[^/]+", source):
|
elif re.match(r"^https?://[^/]+", source):
|
||||||
# Pull the token from config if it exists and matches the URL
|
|
||||||
_token = access_token
|
|
||||||
if _token is None:
|
|
||||||
for pair in self.app_config.remote_api_tokens or []:
|
|
||||||
if re.search(pair.url_regex, source):
|
|
||||||
_token = pair.token
|
|
||||||
break
|
|
||||||
source_obj = URLModelSource(
|
source_obj = URLModelSource(
|
||||||
url=AnyHttpUrl(source),
|
url=AnyHttpUrl(source),
|
||||||
access_token=_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported model source: '{source}'")
|
raise ValueError(f"Unsupported model source: '{source}'")
|
||||||
|
@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
def load_ckpt_from_url(
|
def load_ckpt_from_url(
|
||||||
self,
|
self,
|
||||||
source: str | AnyHttpUrl,
|
source: str | AnyHttpUrl,
|
||||||
access_token: Optional[str] = None,
|
|
||||||
timeout: Optional[int] = 0,
|
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||||
) -> LoadedModel:
|
) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
|
|||||||
Args:
|
Args:
|
||||||
source: A URL or a string that can be converted in one. Repo_ids
|
source: A URL or a string that can be converted in one. Repo_ids
|
||||||
do not work here.
|
do not work here.
|
||||||
access_token: Optional access token for restricted resources.
|
|
||||||
timeout: Wait up to the indicated number of seconds before timing
|
|
||||||
out long downloads.
|
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
def load_ckpt_from_url(
|
def load_ckpt_from_url(
|
||||||
self,
|
self,
|
||||||
source: str | AnyHttpUrl,
|
source: str | AnyHttpUrl,
|
||||||
access_token: Optional[str] = None,
|
|
||||||
timeout: Optional[int] = 0,
|
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||||
) -> LoadedModel:
|
) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
|
|||||||
Args:
|
Args:
|
||||||
source: A URL or a string that can be converted in one. Repo_ids
|
source: A URL or a string that can be converted in one. Repo_ids
|
||||||
do not work here.
|
do not work here.
|
||||||
access_token: Optional access token for restricted resources.
|
|
||||||
timeout: Wait up to the indicated number of seconds before timing
|
|
||||||
out long downloads.
|
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LoadedModel object.
|
A LoadedModel object.
|
||||||
"""
|
"""
|
||||||
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
|
model_path = self.install.download_and_cache_ckpt(source=source)
|
||||||
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
|
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)
|
||||||
|
@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
def load_ckpt_from_url(
|
def load_ckpt_from_url(
|
||||||
self,
|
self,
|
||||||
source: str | AnyHttpUrl,
|
source: str | AnyHttpUrl,
|
||||||
access_token: Optional[str] = None,
|
|
||||||
timeout: Optional[int] = 0,
|
|
||||||
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
|
||||||
) -> LoadedModel:
|
) -> LoadedModel:
|
||||||
"""
|
"""
|
||||||
@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
Args:
|
Args:
|
||||||
source: A URL or a string that can be converted in one. Repo_ids
|
source: A URL or a string that can be converted in one. Repo_ids
|
||||||
do not work here.
|
do not work here.
|
||||||
access_token: Optional access token for restricted resources.
|
|
||||||
timeout: Wait up to the indicated number of seconds before timing
|
|
||||||
out long downloads.
|
|
||||||
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LoadedModel object.
|
A LoadedModel object.
|
||||||
"""
|
"""
|
||||||
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
|
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
|
||||||
source=source, access_token=access_token, timeout=timeout, loader=loader
|
|
||||||
)
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,14 +2,19 @@
|
|||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pydantic.networks import AnyHttpUrl
|
from pydantic.networks import AnyHttpUrl
|
||||||
from requests.sessions import Session
|
from requests.sessions import Session
|
||||||
from requests_testadapter import TestAdapter, TestSession
|
from requests_testadapter import TestAdapter, TestSession
|
||||||
|
|
||||||
|
from invokeai.app.services.config import get_config
|
||||||
|
from invokeai.app.services.config.config_default import URLRegexTokenPair
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
from invokeai.app.services.download import DownloadJob, DownloadJobStatus, DownloadQueueService
|
||||||
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
from tests.test_nodes import TestEventService
|
from tests.test_nodes import TestEventService
|
||||||
|
|
||||||
# Prevent pytest deprecation warnings
|
# Prevent pytest deprecation warnings
|
||||||
@ -34,6 +39,17 @@ def session() -> Session:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
sess.mount(
|
||||||
|
"http://www.huggingface.co/foo.txt",
|
||||||
|
TestAdapter(
|
||||||
|
content,
|
||||||
|
headers={
|
||||||
|
"Content-Length": len(content),
|
||||||
|
"Content-Disposition": 'filename="foo.safetensors"',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# here are some malformed URLs to test
|
# here are some malformed URLs to test
|
||||||
# missing the content length
|
# missing the content length
|
||||||
sess.mount(
|
sess.mount(
|
||||||
@ -205,3 +221,37 @@ def test_cancel(tmp_path: Path, session: Session) -> None:
|
|||||||
assert events[-1].event_name == "download_cancelled"
|
assert events[-1].event_name == "download_cancelled"
|
||||||
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
assert events[-1].payload["source"] == "http://www.civitai.com/models/12345"
|
||||||
queue.stop()
|
queue.stop()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def clear_config() -> Generator[None, None, None]:
|
||||||
|
try:
|
||||||
|
yield None
|
||||||
|
finally:
|
||||||
|
get_config.cache_clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tokens(tmp_path: Path, session: Session):
|
||||||
|
with clear_config():
|
||||||
|
config = get_config()
|
||||||
|
config.remote_api_tokens = [URLRegexTokenPair(url_regex="civitai", token="cv_12345")]
|
||||||
|
queue = DownloadQueueService(requests_session=session)
|
||||||
|
queue.start()
|
||||||
|
# this one has an access token assigned
|
||||||
|
job1 = queue.download(
|
||||||
|
source=AnyHttpUrl("http://www.civitai.com/models/12345"),
|
||||||
|
dest=tmp_path,
|
||||||
|
)
|
||||||
|
# this one doesn't
|
||||||
|
job2 = queue.download(
|
||||||
|
source=AnyHttpUrl(
|
||||||
|
"http://www.huggingface.co/foo.txt",
|
||||||
|
),
|
||||||
|
dest=tmp_path,
|
||||||
|
)
|
||||||
|
queue.join()
|
||||||
|
# this token is defined in the temporary root invokeai.yaml
|
||||||
|
# see tests/backend/model_manager/data/invokeai_root/invokeai.yaml
|
||||||
|
assert job1.access_token == "cv_12345"
|
||||||
|
assert job2.access_token is None
|
||||||
|
queue.stop()
|
||||||
|
Loading…
Reference in New Issue
Block a user