From a2cc4047f9f026d48218145bd3ff925a7936949c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 12 Feb 2024 14:27:17 -0500 Subject: [PATCH] add a JIT download_and_cache() call to the model installer --- docs/contributing/MODEL_MANAGER.md | 40 +++++++++++++++ .../app/services/download/download_base.py | 13 +++++ .../app/services/download/download_default.py | 15 +++++- .../model_install/model_install_base.py | 34 ++++++++++++- .../model_install/model_install_default.py | 49 ++++++++++++++++++- .../convert_cache/convert_cache_default.py | 8 ++- 6 files changed, 154 insertions(+), 5 deletions(-) diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md index 959b7f9733..b711c654de 100644 --- a/docs/contributing/MODEL_MANAGER.md +++ b/docs/contributing/MODEL_MANAGER.md @@ -792,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will return from the call if jobs aren't completed in the specified time. An argument of 0 (the default) will block indefinitely. +#### job = installer.wait_for_job(job, [timeout]) + +Like `wait_for_installs()`, but block until a specific job has +completed or errored, and then return the job. The optional `timeout` +argument will raise a `TimeoutError` if the job doesn't complete in the +specified time. An argument of 0 (the default) will block +indefinitely. + #### jobs = installer.list_jobs() Return a list of all active and complete `ModelInstallJobs`. @@ -854,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally deletes the corresponding model weights file(s), regardless of whether they are inside or outside the InvokeAI models hierarchy. + +#### path = installer.download_and_cache(remote_source, [access_token], [timeout]) + +This utility routine will download the model file located at source, +cache it, and return the path to the cached file. 
It does not attempt +to determine the model type, probe its configuration values, or +register it with the models database. + +You may provide an access token if the remote source requires +authorization. The call will block indefinitely until the file is +completely downloaded, cancelled or raises an error of some sort. If +you provide a timeout (in seconds), the call will raise a +`TimeoutError` exception if the download hasn't completed in the +specified period. + +You may use this mechanism to request any type of file, not just a +model. The file will be stored in a subdirectory of +`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the +cache, its path will be returned without redownloading it. + +Be aware that the models cache is cleared of infrequently-used files +and directories at regular intervals when the size of the cache +exceeds the value specified in Invoke's `convert_cache` configuration +variable. + #### List[str]=installer.scan_directory(scan_dir: Path, install: bool) This method will recursively scan the directory indicated in @@ -1187,6 +1220,13 @@ queue or was not created by this queue. This method will block until all the active jobs in the queue have reached a terminal state (completed, errored or cancelled). +#### queue.wait_for_job(job, [timeout]) + +This method will block until the indicated job has reached a terminal +state (completed, errored or cancelled). If the optional timeout is +provided, the call will block for at most timeout seconds, and raise a +TimeoutError otherwise. 
+ #### jobs = queue.list_jobs() This will return a list of all jobs, including ones that have not yet diff --git a/invokeai/app/services/download/download_base.py b/invokeai/app/services/download/download_base.py index f854f64f58..2ac13b825f 100644 --- a/invokeai/app/services/download/download_base.py +++ b/invokeai/app/services/download/download_base.py @@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC): def join(self) -> None: """Wait until all jobs are off the queue.""" pass + + @abstractmethod + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Wait until the indicated download job has reached a terminal state. + + This will block until the indicated download job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + pass diff --git a/invokeai/app/services/download/download_default.py b/invokeai/app/services/download/download_default.py index 7613c0893f..f740c50087 100644 --- a/invokeai/app/services/download/download_default.py +++ b/invokeai/app/services/download/download_default.py @@ -4,6 +4,7 @@ import os import re import threading +import time import traceback from pathlib import Path from queue import Empty, PriorityQueue @@ -52,6 +53,7 @@ class DownloadQueueService(DownloadQueueServiceBase): self._next_job_id = 0 self._queue = PriorityQueue() self._stop_event = threading.Event() + self._job_completed_event = threading.Event() self._worker_pool = set() self._lock = threading.Lock() self._logger = InvokeAILogger.get_logger("DownloadQueueService") @@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase): if not job.in_terminal_state: self.cancel_job(job) + def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = 
time.time() + while not job.in_terminal_state: + if self._job_completed_event.wait(timeout=5): # in case we miss an event + self._job_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + def _start_workers(self, max_workers: int) -> None: """Start the requested number of worker threads.""" self._stop_event.clear() @@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase): finally: job.job_ended = get_iso_timestamp() + self._job_completed_event.set() # signal a change to terminal state self._queue.task_done() self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.") @@ -407,7 +420,7 @@ class DownloadQueueService(DownloadQueueServiceBase): # Example on_progress event handler to display a TQDM status bar # Activate with: -# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update +# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update)) class TqdmProgress(object): """TQDM-based progress bar object to use in on_progress handlers.""" diff --git a/invokeai/app/services/model_install/model_install_base.py b/invokeai/app/services/model_install/model_install_base.py index 943cdf1157..39ea8c4a0d 100644 --- a/invokeai/app/services/model_install/model_install_base.py +++ b/invokeai/app/services/model_install/model_install_base.py @@ -422,6 +422,18 @@ class ModelInstallServiceBase(ABC): def cancel_job(self, job: ModelInstallJob) -> None: """Cancel the indicated job.""" + @abstractmethod + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Wait for the indicated job to reach a terminal state. + + This will block until the indicated install job has completed, + been cancelled, or errored out. + + :param job: The job to wait on. + :param timeout: Wait up to indicated number of seconds. 
Raise a TimeoutError if + the job hasn't completed within the indicated time. + """ + @abstractmethod def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: """ @@ -431,7 +443,8 @@ class ModelInstallServiceBase(ABC): completed, been cancelled, or errored out. :param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if - installs do not complete within the indicated time. + installs do not complete within the indicated time. A timeout of zero (the default) + will block indefinitely until the installs complete. """ @abstractmethod @@ -447,3 +460,22 @@ class ModelInstallServiceBase(ABC): @abstractmethod def sync_to_config(self) -> None: """Synchronize models on disk to those in the model record database.""" + + @abstractmethod + def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path: + """ + Download the model file located at source to the models cache and return its Path. + + :param source: A Url or a string that can be converted into one. + :param access_token: Optional access token to access restricted resources. + + The model file will be downloaded into the system-wide model cache + (`models/.cache`) if it isn't already there. Note that the model cache + is periodically cleared of infrequently-used entries when the model + converter runs. + + Note that this doesn't automatically install or register the model, but is + intended for use by nodes that need access to models that aren't directly + supported by InvokeAI. The downloading process takes advantage of the download queue + to avoid interrupting other operations. 
+ """ diff --git a/invokeai/app/services/model_install/model_install_default.py b/invokeai/app/services/model_install/model_install_default.py index df73fcb8cb..414e300715 100644 --- a/invokeai/app/services/model_install/model_install_default.py +++ b/invokeai/app/services/model_install/model_install_default.py @@ -17,7 +17,7 @@ from pydantic.networks import AnyHttpUrl from requests import Session from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase +from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress from invokeai.app.services.events.events_base import EventServiceBase from invokeai.app.services.invoker import Invoker from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL @@ -87,6 +87,7 @@ class ModelInstallService(ModelInstallServiceBase): self._lock = threading.Lock() self._stop_event = threading.Event() self._downloads_changed_event = threading.Event() + self._install_completed_event = threading.Event() self._download_queue = download_queue self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {} self._running = False @@ -241,6 +242,17 @@ class ModelInstallService(ModelInstallServiceBase): assert isinstance(jobs[0], ModelInstallJob) return jobs[0] + def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob: + """Block until the indicated job has reached terminal state, or when timeout limit reached.""" + start = time.time() + while not job.in_terminal_state: + if self._install_completed_event.wait(timeout=5): # in case we miss an event + self._install_completed_event.clear() + if timeout > 0 and time.time() - start > timeout: + raise TimeoutError("Timeout exceeded") + return job + + # TODO: Better name? Maybe wait_for_jobs()? 
Maybe too easily confused with above def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102 """Block until all installation jobs are done.""" start = time.time() @@ -248,7 +260,7 @@ class ModelInstallService(ModelInstallServiceBase): if self._downloads_changed_event.wait(timeout=5): # in case we miss an event self._downloads_changed_event.clear() if timeout > 0 and time.time() - start > timeout: - raise Exception("Timeout exceeded") + raise TimeoutError("Timeout exceeded") self._install_queue.join() return self._install_jobs @@ -302,6 +314,38 @@ class ModelInstallService(ModelInstallServiceBase): path.unlink() self.unregister(key) + def download_and_cache( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: int = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path.""" + model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32] + model_path = self._app_config.models_convert_cache_path / model_hash + + # We expect the cache directory to contain one and only one downloaded file. + # We don't know the file's name in advance, as it is set by the download + # content-disposition header. 
+ if model_path.exists(): + contents = [x for x in model_path.iterdir() if x.is_file()] + if len(contents) > 0: + return contents[0] + + model_path.mkdir(parents=True, exist_ok=True) + job = self._download_queue.download( + source=AnyHttpUrl(str(source)), + dest=model_path, + access_token=access_token, + on_progress=TqdmProgress().update, + ) + self._download_queue.wait_for_job(job, timeout) + if job.complete: + assert job.download_path is not None + return job.download_path + else: + raise Exception(job.error) + # -------------------------------------------------------------------------------------------- # Internal functions that manage the installer threads # -------------------------------------------------------------------------------------------- @@ -365,6 +409,7 @@ class ModelInstallService(ModelInstallServiceBase): # if this is an install of a remote file, then clean up the temporary directory if job._install_tmpdir is not None: rmtree(job._install_tmpdir) + self._install_completed_event.set() self._install_queue.task_done() self._logger.info("Install thread exiting") diff --git a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py index 4c361258d9..84f4f76299 100644 --- a/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py +++ b/invokeai/backend/model_manager/load/convert_cache/convert_cache_default.py @@ -53,7 +53,13 @@ class ModelConvertCache(ModelConvertCacheBase): sentinel = path / config if sentinel.exists(): return sentinel.stat().st_atime - return 0.0 + + # no sentinel file found! - pick the most recent file in the directory + try: + atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True) + return atimes[0] + except IndexError: + return 0.0 # sort by last access time - least accessed files will be at the end lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)