mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add a JIT download_and_cache() call to the model installer
This commit is contained in:
parent
4027e845d4
commit
a2cc4047f9
@ -792,6 +792,14 @@ returns a list of completed jobs. The optional `timeout` argument will
|
|||||||
return from the call if jobs aren't completed in the specified
|
return from the call if jobs aren't completed in the specified
|
||||||
time. An argument of 0 (the default) will block indefinitely.
|
time. An argument of 0 (the default) will block indefinitely.
|
||||||
|
|
||||||
|
#### job = installer.wait_for_job(job, [timeout])
|
||||||
|
|
||||||
|
Like `wait_for_installs()`, but block until a specific job has
|
||||||
|
completed or errored, and then return the job. The optional `timeout`
|
||||||
|
argument will raise a `TimeoutError` if the job doesn't complete in the
|
||||||
|
specified time. An argument of 0 (the default) will block
|
||||||
|
indefinitely.
|
||||||
|
|
||||||
#### jobs = installer.list_jobs()
|
#### jobs = installer.list_jobs()
|
||||||
|
|
||||||
Return a list of all active and complete `ModelInstallJobs`.
|
Return a list of all active and complete `ModelInstallJobs`.
|
||||||
@ -854,6 +862,31 @@ This method is similar to `unregister()`, but also unconditionally
|
|||||||
deletes the corresponding model weights file(s), regardless of whether
|
deletes the corresponding model weights file(s), regardless of whether
|
||||||
they are inside or outside the InvokeAI models hierarchy.
|
they are inside or outside the InvokeAI models hierarchy.
|
||||||
|
|
||||||
|
|
||||||
|
#### path = installer.download_and_cache(remote_source, [access_token], [timeout])
|
||||||
|
|
||||||
|
This utility routine will download the model file located at `remote_source`,
|
||||||
|
cache it, and return the path to the cached file. It does not attempt
|
||||||
|
to determine the model type, probe its configuration values, or
|
||||||
|
register it with the models database.
|
||||||
|
|
||||||
|
You may provide an access token if the remote source requires
|
||||||
|
authorization. The call will block indefinitely until the download
|
||||||
|
completes, is cancelled, or fails with an error of some sort. If
|
||||||
|
you provide a timeout (in seconds), the call will raise a
|
||||||
|
`TimeoutError` exception if the download hasn't completed in the
|
||||||
|
specified period.
|
||||||
|
|
||||||
|
You may use this mechanism to request any type of file, not just a
|
||||||
|
model. The file will be stored in a subdirectory of
|
||||||
|
`INVOKEAI_ROOT/models/.cache`. If the requested file is found in the
|
||||||
|
cache, its path will be returned without redownloading it.
|
||||||
|
|
||||||
|
Be aware that the models cache is cleared of infrequently-used files
|
||||||
|
and directories at regular intervals when the size of the cache
|
||||||
|
exceeds the value specified in Invoke's `convert_cache` configuration
|
||||||
|
variable.
|
||||||
|
|
||||||
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
#### List[str]=installer.scan_directory(scan_dir: Path, install: bool)
|
||||||
|
|
||||||
This method will recursively scan the directory indicated in
|
This method will recursively scan the directory indicated in
|
||||||
@ -1187,6 +1220,13 @@ queue or was not created by this queue.
|
|||||||
This method will block until all the active jobs in the queue have
|
This method will block until all the active jobs in the queue have
|
||||||
reached a terminal state (completed, errored or cancelled).
|
reached a terminal state (completed, errored or cancelled).
|
||||||
|
|
||||||
|
#### queue.wait_for_job(job, [timeout])
|
||||||
|
|
||||||
|
This method will block until the indicated job has reached a terminal
|
||||||
|
state (completed, errored or cancelled). If the optional timeout is
|
||||||
|
provided, the call will block for at most `timeout` seconds, and raise a
|
||||||
|
`TimeoutError` if the job has not reached a terminal state by then.
|
||||||
|
|
||||||
#### jobs = queue.list_jobs()
|
#### jobs = queue.list_jobs()
|
||||||
|
|
||||||
This will return a list of all jobs, including ones that have not yet
|
This will return a list of all jobs, including ones that have not yet
|
||||||
|
@ -260,3 +260,16 @@ class DownloadQueueServiceBase(ABC):
|
|||||||
def join(self) -> None:
|
def join(self) -> None:
|
||||||
"""Wait until all jobs are off the queue."""
|
"""Wait until all jobs are off the queue."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
    """Wait until the indicated download job has reached a terminal state.

    This will block until the indicated download job has completed,
    been cancelled, or errored out.

    :param job: The job to wait on.
    :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
    the job hasn't completed within the indicated time. A timeout of zero (the
    default) will block indefinitely.
    :return: The job that was waited on.
    """
    pass
|
||||||
|
@ -4,6 +4,7 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Empty, PriorityQueue
|
from queue import Empty, PriorityQueue
|
||||||
@ -52,6 +53,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
self._next_job_id = 0
|
self._next_job_id = 0
|
||||||
self._queue = PriorityQueue()
|
self._queue = PriorityQueue()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
|
self._job_completed_event = threading.Event()
|
||||||
self._worker_pool = set()
|
self._worker_pool = set()
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
self._logger = InvokeAILogger.get_logger("DownloadQueueService")
|
||||||
@ -188,6 +190,16 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
if not job.in_terminal_state:
|
if not job.in_terminal_state:
|
||||||
self.cancel_job(job)
|
self.cancel_job(job)
|
||||||
|
|
||||||
|
def wait_for_job(self, job: DownloadJob, timeout: int = 0) -> DownloadJob:
    """Block until the indicated job has reached a terminal state.

    :param job: The download job to wait on.
    :param timeout: Maximum number of seconds to wait. Zero (the default)
    waits indefinitely.
    :return: The same job object that was passed in.
    :raises TimeoutError: If the job has not reached a terminal state once
    `timeout` seconds have elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout if timeout > 0 else None
    while not job.in_terminal_state:
        # Re-check the job at least every 5 seconds in case a completion
        # event is missed, but never sleep past the caller's deadline.
        poll_interval = 5.0
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("Timeout exceeded")
            poll_interval = min(poll_interval, remaining)
        if self._job_completed_event.wait(timeout=poll_interval):
            self._job_completed_event.clear()
    return job
|
||||||
|
|
||||||
def _start_workers(self, max_workers: int) -> None:
|
def _start_workers(self, max_workers: int) -> None:
|
||||||
"""Start the requested number of worker threads."""
|
"""Start the requested number of worker threads."""
|
||||||
self._stop_event.clear()
|
self._stop_event.clear()
|
||||||
@ -223,6 +235,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
|
|
||||||
finally:
|
finally:
|
||||||
job.job_ended = get_iso_timestamp()
|
job.job_ended = get_iso_timestamp()
|
||||||
|
self._job_completed_event.set() # signal a change to terminal state
|
||||||
self._queue.task_done()
|
self._queue.task_done()
|
||||||
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
self._logger.debug(f"Download queue worker thread {threading.current_thread().name} exiting.")
|
||||||
|
|
||||||
@ -407,7 +420,7 @@ class DownloadQueueService(DownloadQueueServiceBase):
|
|||||||
|
|
||||||
# Example on_progress event handler to display a TQDM status bar
|
# Example on_progress event handler to display a TQDM status bar
|
||||||
# Activate with:
|
# Activate with:
|
||||||
# download_service.download('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().job_update
|
# download_service.download(DownloadJob('http://foo.bar/baz', '/tmp', on_progress=TqdmProgress().update))
|
||||||
class TqdmProgress(object):
|
class TqdmProgress(object):
|
||||||
"""TQDM-based progress bar object to use in on_progress handlers."""
|
"""TQDM-based progress bar object to use in on_progress handlers."""
|
||||||
|
|
||||||
|
@ -422,6 +422,18 @@ class ModelInstallServiceBase(ABC):
|
|||||||
def cancel_job(self, job: ModelInstallJob) -> None:
|
def cancel_job(self, job: ModelInstallJob) -> None:
|
||||||
"""Cancel the indicated job."""
|
"""Cancel the indicated job."""
|
||||||
|
|
||||||
|
@abstractmethod
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
    """Wait for the indicated job to reach a terminal state.

    This will block until the indicated install job has completed,
    been cancelled, or errored out.

    :param job: The job to wait on.
    :param timeout: Wait up to indicated number of seconds. Raise a TimeoutError if
    the job hasn't completed within the indicated time. A timeout of zero (the
    default) will block indefinitely until the job reaches a terminal state.
    :return: The job that was waited on.
    """
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]:
|
||||||
"""
|
"""
|
||||||
@ -431,7 +443,8 @@ class ModelInstallServiceBase(ABC):
|
|||||||
completed, been cancelled, or errored out.
|
completed, been cancelled, or errored out.
|
||||||
|
|
||||||
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
:param timeout: Wait up to indicated number of seconds. Raise an Exception('timeout') if
|
||||||
installs do not complete within the indicated time.
|
installs do not complete within the indicated time. A timeout of zero (the default)
|
||||||
|
will block indefinitely until the installs complete.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
@ -447,3 +460,22 @@ class ModelInstallServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def sync_to_config(self) -> None:
|
def sync_to_config(self) -> None:
|
||||||
"""Synchronize models on disk to those in the model record database."""
|
"""Synchronize models on disk to those in the model record database."""
|
||||||
|
|
||||||
|
@abstractmethod
def download_and_cache(
    self,
    source: Union[str, AnyHttpUrl],
    access_token: Optional[str] = None,
    timeout: int = 0,
) -> Path:
    """
    Download the model file located at source to the models cache and return its Path.

    :param source: A Url or a string that can be converted into one.
    :param access_token: Optional access token to access restricted resources.
    :param timeout: Wait up to indicated number of seconds for the download to
    finish. Raise a TimeoutError if it hasn't completed within the indicated time.
    A timeout of zero (the default) will block indefinitely.

    The model file will be downloaded into the system-wide model cache
    (`models/.cache`) if it isn't already there. Note that the model cache
    is periodically cleared of infrequently-used entries when the model
    converter runs.

    Note that this doesn't automatically install or register the model, but is
    intended for use by nodes that need access to models that aren't directly
    supported by InvokeAI. The downloading process takes advantage of the download queue
    to avoid interrupting other operations.
    """
|
||||||
|
@ -17,7 +17,7 @@ from pydantic.networks import AnyHttpUrl
|
|||||||
from requests import Session
|
from requests import Session
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase
|
from invokeai.app.services.download import DownloadJob, DownloadQueueServiceBase, TqdmProgress
|
||||||
from invokeai.app.services.events.events_base import EventServiceBase
|
from invokeai.app.services.events.events_base import EventServiceBase
|
||||||
from invokeai.app.services.invoker import Invoker
|
from invokeai.app.services.invoker import Invoker
|
||||||
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
from invokeai.app.services.model_records import DuplicateModelException, ModelRecordServiceBase, ModelRecordServiceSQL
|
||||||
@ -87,6 +87,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._stop_event = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self._downloads_changed_event = threading.Event()
|
self._downloads_changed_event = threading.Event()
|
||||||
|
self._install_completed_event = threading.Event()
|
||||||
self._download_queue = download_queue
|
self._download_queue = download_queue
|
||||||
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
self._download_cache: Dict[AnyHttpUrl, ModelInstallJob] = {}
|
||||||
self._running = False
|
self._running = False
|
||||||
@ -241,6 +242,17 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
assert isinstance(jobs[0], ModelInstallJob)
|
assert isinstance(jobs[0], ModelInstallJob)
|
||||||
return jobs[0]
|
return jobs[0]
|
||||||
|
|
||||||
|
def wait_for_job(self, job: ModelInstallJob, timeout: int = 0) -> ModelInstallJob:
    """Block until the indicated job has reached a terminal state.

    :param job: The install job to wait on.
    :param timeout: Maximum number of seconds to wait. Zero (the default)
    waits indefinitely.
    :return: The same job object that was passed in.
    :raises TimeoutError: If the job has not reached a terminal state once
    `timeout` seconds have elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout if timeout > 0 else None
    while not job.in_terminal_state:
        # Re-check the job at least every 5 seconds in case a completion
        # event is missed, but never sleep past the caller's deadline.
        poll_interval = 5.0
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("Timeout exceeded")
            poll_interval = min(poll_interval, remaining)
        if self._install_completed_event.wait(timeout=poll_interval):
            self._install_completed_event.clear()
    return job
|
||||||
|
|
||||||
|
# TODO: Better name? Maybe wait_for_jobs()? Maybe too easily confused with above
|
||||||
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
def wait_for_installs(self, timeout: int = 0) -> List[ModelInstallJob]: # noqa D102
|
||||||
"""Block until all installation jobs are done."""
|
"""Block until all installation jobs are done."""
|
||||||
start = time.time()
|
start = time.time()
|
||||||
@ -248,7 +260,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
if self._downloads_changed_event.wait(timeout=5): # in case we miss an event
|
||||||
self._downloads_changed_event.clear()
|
self._downloads_changed_event.clear()
|
||||||
if timeout > 0 and time.time() - start > timeout:
|
if timeout > 0 and time.time() - start > timeout:
|
||||||
raise Exception("Timeout exceeded")
|
raise TimeoutError("Timeout exceeded")
|
||||||
self._install_queue.join()
|
self._install_queue.join()
|
||||||
return self._install_jobs
|
return self._install_jobs
|
||||||
|
|
||||||
@ -302,6 +314,38 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
self.unregister(key)
|
self.unregister(key)
|
||||||
|
|
||||||
|
def download_and_cache(
    self,
    source: Union[str, AnyHttpUrl],
    access_token: Optional[str] = None,
    timeout: int = 0,
) -> Path:
    """Fetch the file at `source` into the models cache and return its location.

    The cache subdirectory is named after the first 32 hex digits of the
    SHA-256 digest of the source string. If that directory already holds a
    file, the file's path is returned without contacting the download queue.

    :param source: A Url or a string that can be converted into one.
    :param access_token: Optional access token passed through to the download queue.
    :param timeout: Number of seconds handed to the download queue's
    `wait_for_job()` call.
    :return: Path of the cached (or freshly downloaded) file.
    :raises Exception: Carrying the job's error if the job ended without completing.
    """
    digest = sha256(str(source).encode("utf-8")).hexdigest()[0:32]
    cache_dir = self._app_config.models_convert_cache_path / digest

    # The cache directory is expected to hold exactly one downloaded file.
    # Its name cannot be predicted because it is taken from the download's
    # content-disposition header, so return whichever file is found first.
    if cache_dir.exists():
        cached_file = next((entry for entry in cache_dir.iterdir() if entry.is_file()), None)
        if cached_file is not None:
            return cached_file

    cache_dir.mkdir(parents=True, exist_ok=True)
    download_job = self._download_queue.download(
        source=AnyHttpUrl(str(source)),
        dest=cache_dir,
        access_token=access_token,
        on_progress=TqdmProgress().update,
    )
    self._download_queue.wait_for_job(download_job, timeout)

    # Guard clause: anything other than a completed job is reported as an error.
    if not download_job.complete:
        raise Exception(download_job.error)

    assert download_job.download_path is not None
    return download_job.download_path
|
||||||
|
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
# Internal functions that manage the installer threads
|
# Internal functions that manage the installer threads
|
||||||
# --------------------------------------------------------------------------------------------
|
# --------------------------------------------------------------------------------------------
|
||||||
@ -365,6 +409,7 @@ class ModelInstallService(ModelInstallServiceBase):
|
|||||||
# if this is an install of a remote file, then clean up the temporary directory
|
# if this is an install of a remote file, then clean up the temporary directory
|
||||||
if job._install_tmpdir is not None:
|
if job._install_tmpdir is not None:
|
||||||
rmtree(job._install_tmpdir)
|
rmtree(job._install_tmpdir)
|
||||||
|
self._install_completed_event.set()
|
||||||
self._install_queue.task_done()
|
self._install_queue.task_done()
|
||||||
|
|
||||||
self._logger.info("Install thread exiting")
|
self._logger.info("Install thread exiting")
|
||||||
|
@ -53,7 +53,13 @@ class ModelConvertCache(ModelConvertCacheBase):
|
|||||||
sentinel = path / config
|
sentinel = path / config
|
||||||
if sentinel.exists():
|
if sentinel.exists():
|
||||||
return sentinel.stat().st_atime
|
return sentinel.stat().st_atime
|
||||||
return 0.0
|
|
||||||
|
# no sentinel file found! - pick the most recent file in the directory
|
||||||
|
try:
|
||||||
|
atimes = sorted([x.stat().st_atime for x in path.iterdir() if x.is_file()], reverse=True)
|
||||||
|
return atimes[0]
|
||||||
|
except IndexError:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
# sort by last access time - least accessed files will be at the end
|
# sort by last access time - least accessed files will be at the end
|
||||||
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
|
lru_models = sorted(self._cache_path.iterdir(), key=by_atime, reverse=True)
|
||||||
|
Loading…
Reference in New Issue
Block a user