From 598fe8101eba945e48b74b5d68d9ca4f98a49d6d Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Sat, 9 Sep 2023 11:42:07 -0400 Subject: [PATCH] wire together download and install; now need to write install events --- .../backend/model_manager/download/queue.py | 1 - invokeai/backend/model_manager/install.py | 103 ++++++++++++++++-- 2 files changed, 96 insertions(+), 8 deletions(-) diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index 6182c01a75..5824c2598a 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -159,7 +159,6 @@ class DownloadQueue(DownloadQueueBase): self._lock.acquire() assert isinstance(self._jobs[job.id], DownloadJobBase) self._update_job_status(job, DownloadJobStatus.CANCELLED) - # del self._jobs[job.id] except (AssertionError, KeyError) as excp: raise UnknownJobIDException("Unrecognized job") from excp finally: diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 9851ce40a7..83bf282754 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -7,10 +7,12 @@ Typical usage: from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager import ModelInstall from invokeai.backend.model_manager.storage import ModelConfigStoreSQL + from invokeai.backend.model_manager.download import DownloadQueue config = InvokeAIAppConfig.get_config() store = ModelConfigStoreSQL(config.db_path) - installer = ModelInstall(store=store, config=config) + download = DownloadQueue() + installer = ModelInstall(store=store, config=config, download=download) # register config, don't move path id: str = installer.register_model('/path/to/model') @@ -18,6 +20,15 @@ Typical usage: # register config, and install model in `models` id: str = installer.install_model('/path/to/model') + # download some remote models and install them in the background + installer.download('stabilityai/stable-diffusion-2-1') + installer.download('https://civitai.com/api/download/models/154208') + installer.download('runwayml/stable-diffusion-v1-5') + + installed_ids = installer.wait_for_downloads() + id1 = installed_ids['stabilityai/stable-diffusion-2-1'] + id2 = installed_ids['https://civitai.com/api/download/models/154208'] + # unregister, don't delete installer.unregister(id) @@ -36,14 +47,17 @@ The following exceptions may be raised: DuplicateModelException UnknownModelTypeException """ +import tempfile from abc import ABC, abstractmethod from pathlib import Path from shutil import rmtree -from typing import Optional, List, Union +from typing import Optional, List, Union, Dict +from pydantic.networks import AnyHttpUrl from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.util.logging import InvokeAILogger from .search import ModelSearch from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException +from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase from .hash import FastModelHash from .probe import ModelProbe, ModelProbeInfo, InvalidModelException from .config import ( @@ -100,6 +114,40 @@ class ModelInstallBase(ABC): """ pass + @abstractmethod + def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase: + """ + Download and install the model located at remote site. + + This will download the model located at `source`, + probe it, and install it into the models directory. + This call is executed asynchronously in a separate + thread, and the returned object is a + invokeai.backend.model_manager.download.DownloadJobBase + object which can be interrogated to get the status of + the download and install process. Call our `wait_for_downloads()` + method to wait for all downloads to complete. + + :param source: Either a URL or a HuggingFace repo_id. + :returns queue: DownloadQueueBase object. + """ + pass + + @abstractmethod + def wait_for_downloads(self) -> Dict[str, str]: + """ + Wait for all pending downloads to complete. + + This will block until all pending downloads have + completed, been cancelled, or errored out. It will + block indefinitely if one or more jobs are in the + paused state. + + It will return a dict that maps the source model + URL or repo_id to the ID of the installed model. + """ + pass + @abstractmethod def unregister(self, id: str): """ @@ -167,6 +215,9 @@ class ModelInstall(ModelInstallBase): _config: InvokeAIAppConfig _logger: InvokeAILogger _store: ModelConfigStore + _download_queue: DownloadQueueBase + _async_installs: Dict[str, str] + _tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads _legacy_configs = { BaseModelType.StableDiffusion1: { @@ -196,12 +247,14 @@ class ModelInstall(ModelInstallBase): store: Optional[ModelConfigStore] = None, config: Optional[InvokeAIAppConfig] = None, logger: Optional[InvokeAILogger] = None, + download: Optional[DownloadQueueBase] = None, ): # noqa D107 - use base class docstrings self._config = config or InvokeAIAppConfig.get_config() self._logger = logger or InvokeAILogger.getLogger() - if store is None: - store = ModelConfigStoreYAML(config.model_conf_path) - self._store = store + self._store = store or ModelConfigStoreYAML(self._config.model_conf_path) + self._download_queue = download or DownloadQueue() + self._async_installs = dict() + self._tmpdir = None def register(self, model_path: Union[Path, str]) -> str: # noqa D102 model_path = Path(model_path) @@ -232,7 +285,8 @@ class ModelInstall(ModelInstallBase): def install(self, model_path: Union[Path, str]) -> str: # noqa D102 model_path = Path(model_path) info: ModelProbeInfo = ModelProbe.probe(model_path) - dest_path = self._config.models_path / info.base_model.value / info.model_type.value / model_path.name + dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name + dest_path.parent.mkdir(parents=True, exist_ok=True) # if path already exists then we jigger the name to make it unique counter: int = 1 @@ -240,7 +294,7 @@ class ModelInstall(ModelInstallBase): dest_path = dest_path.with_stem(dest_path.stem + f"_{counter:02d}") counter += 1 - self._register( + return self._register( model_path.replace(dest_path), info, ) @@ -253,6 +307,41 @@ class ModelInstall(ModelInstallBase): rmtree(model.path) self.unregister(id) + def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase: # noqa D102 + # choose a temporary directory inside the models directory + models_dir = self._config.models_path + queue = self._download_queue + self._async_installs[source] = None + + def complete_installation(job: DownloadJobBase): + self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})") + if job.status == "completed": + id = self.install(job.destination) + info = self._store.get_model(id) + info.description = f"Downloaded model {info.name}" + info.source_url = str(job.source) + self._store.update_model(id, info) + self._async_installs[job.source] = id + jobs = queue.list_jobs() + if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]: + self._tmpdir = None + + # note - this is probably not going to work. The tmpdir + # will be deleted before the job actually runs. + # Better to do the cleanup in the callback + self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir) + return queue.create_download_job( + source=source, + destdir=self._tmpdir.name, + event_handlers=[complete_installation] + ) + + def wait_for_downloads(self) -> Dict[str, str]: # noqa D102 + self._download_queue.join() + id_map = self._async_installs + self._async_installs = dict() + return id_map + def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102 callback = self._scan_install if install else self._scan_register search = ModelSearch(on_model_found=callback)