add a JIT download_and_cache() call to the model installer

This commit is contained in:
Lincoln Stein
2024-02-12 14:27:17 -05:00
committed by psychedelicious
parent 4027e845d4
commit a2cc4047f9
6 changed files with 154 additions and 5 deletions

View File

@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
def join(self) -> None:
"""Wait until all jobs are off the queue."""
pass
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Wait until the indicated download job has reached a terminal state.
This will block until the indicated install job has completed,
been cancelled, or errored out.
:param job: The job to wait on.
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
the job hasn't completed within the indicated time.
"""
pass

View File

@ -4,6 +4,7 @@
import os
import re
import threading
import time
import traceback
from pathlib import Path
from queue import Empty, PriorityQueue
@ -52,6 +53,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
self._next_job_id = 0
self._queue = PriorityQueue()
self._stop_event = threading.Event()
self._job_completed_event = threading.Event()
self._worker_pool = set()
self._lock = threading.Lock()
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
if not job.in_terminal_state:
self.cancel_job(job)
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._job_completed_event.wait(timeout=5): # in case we miss an event
self._job_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
def _start_workers(self, max_workers: int) -> None:
"""Start the requested number of worker threads."""
self._stop_event.clear()
@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
finally:
job.job_ended = get_iso_timestamp()
self._job_completed_event.set() # signal a change to terminal state
self._queue.task_done()
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
@ -407,7 +420,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
# Example on_progress event handler to display a TQDM status bar
# Activate with:
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
class TqdmProgress(object):
"""TQDM-based progress bar object to use in on_progress handlers."""

View File

@ -422,6 +422,18 @@ class ModelInstallServiceBase(ABC):
def cancel_job(self, job: ModelInstallJob) -> None:
"""Cancel the indicated job."""
@abstractmethod
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
"""Wait for the indicated job to reach a terminal state.
This will block until the indicated install job has completed,
been cancelled, or errored out.
:param job: The job to wait on.
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
the job hasn't completed within the indicated time.
"""
@abstractmethod
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
"""
@ -431,7 +443,8 @@ class ModelInstallServiceBase(ABC):
completed, been cancelled, or errored out.
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
installs do not complete within the indicated time.
installs do not complete within the indicated time. A timeout of zero (the default)
will block indefinitely until the installs complete.
"""
@abstractmethod
@ -447,3 +460,22 @@ class ModelInstallServiceBase(ABC):
@abstractmethod
def sync_to_config(self) -> None:
"""Synchronize models on disk to those in the model record database."""
@abstractmethod
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
"""
Download the model file located at source to the models cache and return its Path.
:param source: A Url or a string that can be converted into one.
:param access_token: Optional access token to access restricted resources.
The model file will be downloaded into the system-wide model cache
(`models/.cache`) if it isn't already there. Note that the model cache
is periodically cleared of infrequently-used entries when the model
converter runs.
Note that this doesn't automaticallly install or register the model, but is
intended for use by nodes that need access to models that aren't directly
supported by InvokeAI. The downloading process takes advantage of the download queue
to avoid interrupting other operations.
"""

View File

@ -17,7 +17,7 @@ from pydantic.networks import AnyHttpUrl
from requests import Session
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
from invokeai.app.services.events.events_base import EventServiceBase
from invokeai.app.services.invoker import Invoker
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
@ -87,6 +87,7 @@ class ModelInstallService(ModelInstallServiceBase):
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._downloads_changed_event = threading.Event()
self._install_completed_event = threading.Event()
self._download_queue = download_queue
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
self._running = False
@ -241,6 +242,17 @@ class ModelInstallService(ModelInstallServiceBase):
assert isinstance(jobs[0], ModelInstallJob)
return jobs[0]
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
start = time.time()
while not job.in_terminal_state:
if self._install_completed_event.wait(timeout=5): # in case we miss an event
self._install_completed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise TimeoutError("Timeout exceeded")
return job
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
"""Block until all installation jobs are done."""
start = time.time()
@ -248,7 +260,7 @@ class ModelInstallService(ModelInstallServiceBase):
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
self._downloads_changed_event.clear()
if timeout > 0 and time.time() - start > timeout:
raise Exception("Timeout exceeded")
raise TimeoutError("Timeout exceeded")
self._install_queue.join()
return self._install_jobs
@ -302,6 +314,38 @@ class ModelInstallService(ModelInstallServiceBase):
path.unlink()
self.unregister(key)
def download_and_cache(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: int = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path."""
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
model_path = self._app_config.models_convert_cache_path / model_hash
# We expect the cache directory to contain one and only one downloaded file.
# We don't know the file's name in advance, as it is set by the download
# content-disposition header.
if model_path.exists():
contents = [x for x in model_path.iterdir() if x.is_file()]
if len(contents) > 0:
return contents[0]
model_path.mkdir(parents=True, exist_ok=True)
job = self._download_queue.download(
source=AnyHttpUrl(str(source)),
dest=model_path,
access_token=access_token,
on_progress=TqdmProgress().update,
)
self._download_queue.wait_for_job(job, timeout)
if job.complete:
assert job.download_path is not None
return job.download_path
else:
raise Exception(job.error)
# --------------------------------------------------------------------------------------------
# Internal functions that manage the installer threads
# --------------------------------------------------------------------------------------------
@ -365,6 +409,7 @@ class ModelInstallService(ModelInstallServiceBase):
# if this is an install of a remote file, then clean up the temporary directory
if job._install_tmpdir is not None:
rmtree(job._install_tmpdir)
self._install_completed_event.set()
self._install_queue.task_done()
self._logger.info("Install thread exiting")