move access token regex matching into download queue

This commit is contained in:
Lincoln Stein
2024-05-05 21:00:31 -04:00
parent 8e5e9b53d6
commit f211c95dbc
7 changed files with 69 additions and 29 deletions

View File

@ -93,7 +93,7 @@ class ApiDependencies:
conditioning = ObjectSerializerForwardCache(
ObjectSerializerDisk[ConditioningFieldData](output_folder / "conditioning", ephemeral=True)
)
download_queue_service = DownloadQueueService(event_bus=events)
download_queue_service = DownloadQueueService(app_config=configuration, event_bus=events)
model_images_service = ModelImageFileStorageDisk(model_images_folder / "model_images")
model_manager = ModelManagerService.build_model_manager(
app_config=configuration,

View File

@ -15,6 +15,7 @@ from pydantic.networks import AnyHttpUrl
from requests import HTTPError
from tqdm import tqdm
from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.util.misc import get_iso_timestamp
from invokeai.backend.util.logging import InvokeAILogger
@ -40,15 +41,18 @@ class DownloadQueueService(DownloadQueueServiceBase):
def __init__(
self,
max_parallel_dl: int = 5,
app_config: Optional[InvokeAIAppConfig] = None,
event_bus: Optional[EventServiceBase] = None,
requests_session: Optional[requests.sessions.Session] = None,
):
"""
Initialize DownloadQueue.
:param app_config: InvokeAIAppConfig object
:param max_parallel_dl: Number of simultaneous downloads allowed [5].
:param requests_session: Optional requests.sessions.Session object, for unit tests.
"""
self._app_config = app_config or get_config()
self._jobs: Dict[int, DownloadJob] = {}
self._next_job_id = 0
self._queue: PriorityQueue[DownloadJob] = PriorityQueue()
@ -139,7 +143,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
source=source,
dest=dest,
priority=priority,
access_token=access_token,
access_token=access_token or self._lookup_access_token(source),
)
self.submit_download_job(
job,
@ -333,6 +337,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
def _in_progress_path(self, path: Path) -> Path:
return path.with_name(path.name + ".downloading")
def _lookup_access_token(self, source: AnyHttpUrl) -> Optional[str]:
# Pull the token from config if it exists and matches the URL
print(self._app_config)
token = None
for pair in self._app_config.remote_api_tokens or []:
if re.search(pair.url_regex, str(source)):
token = pair.token
break
return token
def _signal_job_started(self, job: DownloadJob) -> None:
job.status = DownloadJobStatus.RUNNING
if job.on_start:

View File

@ -222,16 +222,9 @@ class ModelInstallService(ModelInstallServiceBase):
access_token=access_token,
)
elif re.match(r"^https?://[^/]+", source):
# Pull the token from config if it exists and matches the URL
_token = access_token
if _token is None:
for pair in self.app_config.remote_api_tokens or []:
if re.search(pair.url_regex, source):
_token = pair.token
break
source_obj = URLModelSource(
url=AnyHttpUrl(source),
access_token=_token,
access_token=access_token,
)
else:
raise ValueError(f"Unsupported model source: '{source}'")

View File

@ -75,8 +75,6 @@ class ModelManagerServiceBase(ABC):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -94,9 +92,6 @@ class ModelManagerServiceBase(ABC):
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:

View File

@ -106,8 +106,6 @@ class ModelManagerService(ModelManagerServiceBase):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -125,13 +123,10 @@ class ModelManagerService(ModelManagerServiceBase):
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
model_path = self.install.download_and_cache_ckpt(source=source, access_token=access_token, timeout=timeout)
model_path = self.install.download_and_cache_ckpt(source=source)
return self.load.load_ckpt_from_path(model_path=model_path, loader=loader)

View File

@ -496,8 +496,6 @@ class ModelsInterface(InvocationContextInterface):
def load_ckpt_from_url(
self,
source: str | AnyHttpUrl,
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
loader: Optional[Callable[[Path], Dict[str, torch.Tensor]]] = None,
) -> LoadedModel:
"""
@ -515,17 +513,12 @@ class ModelsInterface(InvocationContextInterface):
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
loader: A Callable that expects a Path and returns a Dict[str|int, Any]
Returns:
A LoadedModel object.
"""
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(
source=source, access_token=access_token, timeout=timeout, loader=loader
)
result: LoadedModel = self._services.model_manager.load_ckpt_from_url(source=source, loader=loader)
return result