mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add a JIT download_and_cache() call to the model installer
This commit is contained in:
committed by
psychedelicious
parent
4027e845d4
commit
a2cc4047f9
@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
|
||||
def join(self) -> None:
|
||||
"""Wait until all jobs are off the queue."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Wait until the indicated download job has reached a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
pass
|
||||
|
@ -4,6 +4,7 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
from queue import Empty, PriorityQueue
|
||||
@ -52,6 +53,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
self._next_job_id = 0
|
||||
self._queue = PriorityQueue()
|
||||
self._stop_event = threading.Event()
|
||||
self._job_completed_event = threading.Event()
|
||||
self._worker_pool = set()
|
||||
self._lock = threading.Lock()
|
||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||
@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
if not job.in_terminal_state:
|
||||
self.cancel_job(job)
|
||||
|
||||
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._job_completed_event.wait(timeout=5): # in case we miss an event
|
||||
self._job_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
def _start_workers(self, max_workers: int) -> None:
|
||||
"""Start the requested number of worker threads."""
|
||||
self._stop_event.clear()
|
||||
@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
finally:
|
||||
job.job_ended = get_iso_timestamp()
|
||||
self._job_completed_event.set() # signal a change to terminal state
|
||||
self._queue.task_done()
|
||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||
|
||||
@ -407,7 +420,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
||||
|
||||
# Example on_progress event handler to display a TQDM status bar
|
||||
# Activate with:
|
||||
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
|
||||
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
|
||||
class TqdmProgress(object):
|
||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||
|
||||
|
@ -422,6 +422,18 @@ class ModelInstallServiceBase(ABC):
|
||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||
"""Cancel the indicated job."""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Wait for the indicated job to reach a terminal state.
|
||||
|
||||
This will block until the indicated install job has completed,
|
||||
been cancelled, or errored out.
|
||||
|
||||
:param job: The job to wait on.
|
||||
:param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
|
||||
the job hasn't completed within the indicated time.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
||||
"""
|
||||
@ -431,7 +443,8 @@ class ModelInstallServiceBase(ABC):
|
||||
completed, been cancelled, or errored out.
|
||||
|
||||
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
||||
installs do not complete within the indicated time.
|
||||
installs do not complete within the indicated time. A timeout of zero (the default)
|
||||
will block indefinitely until the installs complete.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -447,3 +460,22 @@ class ModelInstallServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def sync_to_config(self) -> None:
|
||||
"""Synchronize models on disk to those in the model record database."""
|
||||
|
||||
@abstractmethod
|
||||
def download_and_cache(self, source: Union[str, AnyHttpUrl], access_token: Optional[str] = None) -> Path:
|
||||
"""
|
||||
Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
:param source: A Url or a string that can be converted into one.
|
||||
:param access_token: Optional access token to access restricted resources.
|
||||
|
||||
The model file will be downloaded into the system-wide model cache
|
||||
(`models/.cache`) if it isn't already there. Note that the model cache
|
||||
is periodically cleared of infrequently-used entries when the model
|
||||
converter runs.
|
||||
|
||||
Note that this doesn't automaticallly install or register the model, but is
|
||||
intended for use by nodes that need access to models that aren't directly
|
||||
supported by InvokeAI. The downloading process takes advantage of the download queue
|
||||
to avoid interrupting other operations.
|
||||
"""
|
||||
|
@ -17,7 +17,7 @@ from pydantic.networks import AnyHttpUrl
|
||||
from requests import Session
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||
from invokeai.app.services.events.events_base import EventServiceBase
|
||||
from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||
@ -87,6 +87,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
self._lock = threading.Lock()
|
||||
self._stop_event = threading.Event()
|
||||
self._downloads_changed_event = threading.Event()
|
||||
self._install_completed_event = threading.Event()
|
||||
self._download_queue = download_queue
|
||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||
self._running = False
|
||||
@ -241,6 +242,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
assert isinstance(jobs[0], ModelInstallJob)
|
||||
return jobs[0]
|
||||
|
||||
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
|
||||
"""Block until the indicated job has reached terminal state, or when timeout limit reached."""
|
||||
start = time.time()
|
||||
while not job.in_terminal_state:
|
||||
if self._install_completed_event.wait(timeout=5): # in case we miss an event
|
||||
self._install_completed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
return job
|
||||
|
||||
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||
"""Block until all installation jobs are done."""
|
||||
start = time.time()
|
||||
@ -248,7 +260,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||
self._downloads_changed_event.clear()
|
||||
if timeout > 0 and time.time() - start > timeout:
|
||||
raise Exception("Timeout exceeded")
|
||||
raise TimeoutError("Timeout exceeded")
|
||||
self._install_queue.join()
|
||||
return self._install_jobs
|
||||
|
||||
@ -302,6 +314,38 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
path.unlink()
|
||||
self.unregister(key)
|
||||
|
||||
def download_and_cache(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: int = 0,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path."""
|
||||
model_hash = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
|
||||
model_path = self._app_config.models_convert_cache_path / model_hash
|
||||
|
||||
# We expect the cache directory to contain one and only one downloaded file.
|
||||
# We don't know the file's name in advance, as it is set by the download
|
||||
# content-disposition header.
|
||||
if model_path.exists():
|
||||
contents = [x for x in model_path.iterdir() if x.is_file()]
|
||||
if len(contents) > 0:
|
||||
return contents[0]
|
||||
|
||||
model_path.mkdir(parents=True, exist_ok=True)
|
||||
job = self._download_queue.download(
|
||||
source=AnyHttpUrl(str(source)),
|
||||
dest=model_path,
|
||||
access_token=access_token,
|
||||
on_progress=TqdmProgress().update,
|
||||
)
|
||||
self._download_queue.wait_for_job(job, timeout)
|
||||
if job.complete:
|
||||
assert job.download_path is not None
|
||||
return job.download_path
|
||||
else:
|
||||
raise Exception(job.error)
|
||||
|
||||
# --------------------------------------------------------------------------------------------
|
||||
# Internal functions that manage the installer threads
|
||||
# --------------------------------------------------------------------------------------------
|
||||
@ -365,6 +409,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
||||
# if this is an install of a remote file, then clean up the temporary directory
|
||||
if job._install_tmpdir is not None:
|
||||
rmtree(job._install_tmpdir)
|
||||
self._install_completed_event.set()
|
||||
self._install_queue.task_done()
|
||||
|
||||
self._logger.info("Install thread exiting")
|
||||
|
Reference in New Issue
Block a user