wire together download and install; now need to write install events

This commit is contained in:
Lincoln Stein
2023-09-09 11:42:07 -04:00
parent b7ca983f9c
commit 598fe8101e
2 changed files with 96 additions and 8 deletions

View File

@ -159,7 +159,6 @@ class DownloadQueue(DownloadQueueBase):
self._lock.acquire()
assert isinstance(self._jobs[job.id], DownloadJobBase)
self._update_job_status(job, DownloadJobStatus.CANCELLED)
# del self._jobs[job.id]
except (AssertionError, KeyError) as excp:
raise UnknownJobIDException("Unrecognized job") from excp
finally:

View File

@ -7,10 +7,12 @@ Typical usage:
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import ModelInstall
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
from invokeai.backend.model_manager.download import DownloadQueue
config = InvokeAIAppConfig.get_config()
store = ModelConfigStoreSQL(config.db_path)
installer = ModelInstall(store=store, config=config)
download = DownloadQueue()
installer = ModelInstall(store=store, config=config, download=download)
# register config, don't move path
id: str = installer.register_model('/path/to/model')
@ -18,6 +20,15 @@ Typical usage:
# register config, and install model in `models`
id: str = installer.install_model('/path/to/model')
# download some remote models and install them in the background
installer.download('stabilityai/stable-diffusion-2-1')
installer.download('https://civitai.com/api/download/models/154208')
installer.download('runwayml/stable-diffusion-v1-5')
installed_ids = installer.wait_for_downloads()
id1 = installed_ids['stabilityai/stable-diffusion-2-1']
id2 = installed_ids['https://civitai.com/api/download/models/154208']
# unregister, don't delete
installer.unregister(id)
@ -36,14 +47,17 @@ The following exceptions may be raised:
DuplicateModelException
UnknownModelTypeException
"""
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from shutil import rmtree
from typing import Optional, List, Union
from typing import Optional, List, Union, Dict
from pydantic.networks import AnyHttpUrl
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .search import ModelSearch
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase
from .hash import FastModelHash
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
from .config import (
@ -100,6 +114,40 @@ class ModelInstallBase(ABC):
"""
pass
@abstractmethod
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase:
"""
Download and install the model located at remote site.
This will download the model located at `source`,
probe it, and install it into the models directory.
This call is executed asynchronously in a separate
thread, and the returned object is a
invokeai.backend.model_manager.download.DownloadJobBase
object which can be interrogated to get the status of
the download and install process. Call our `wait_for_downloads()`
method to wait for all downloads to complete.
:param source: Either a URL or a HuggingFace repo_id.
:returns queue: DownloadQueueBase object.
"""
pass
@abstractmethod
def wait_for_downloads(self) -> Dict[str, str]:
"""
Wait for all pending downloads to complete.
This will block until all pending downloads have
completed, been cancelled, or errored out. It will
block indefinitely if one or more jobs are in the
paused state.
It will return a dict that maps the source model
URL or repo_id to the ID of the installed model.
"""
pass
@abstractmethod
def unregister(self, id: str):
"""
@ -167,6 +215,9 @@ class ModelInstall(ModelInstallBase):
_config: InvokeAIAppConfig
_logger: InvokeAILogger
_store: ModelConfigStore
_download_queue: DownloadQueueBase
_async_installs: Dict[str, str]
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
_legacy_configs = {
BaseModelType.StableDiffusion1: {
@ -196,12 +247,14 @@ class ModelInstall(ModelInstallBase):
store: Optional[ModelConfigStore] = None,
config: Optional[InvokeAIAppConfig] = None,
logger: Optional[InvokeAILogger] = None,
download: Optional[DownloadQueueBase] = None,
): # noqa D107 - use base class docstrings
self._config = config or InvokeAIAppConfig.get_config()
self._logger = logger or InvokeAILogger.getLogger()
if store is None:
store = ModelConfigStoreYAML(config.model_conf_path)
self._store = store
self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
self._download_queue = download or DownloadQueue()
self._async_installs = dict()
self._tmpdir = None
def register(self, model_path: Union[Path, str]) -> str: # noqa D102
model_path = Path(model_path)
@ -232,7 +285,8 @@ class ModelInstall(ModelInstallBase):
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
model_path = Path(model_path)
info: ModelProbeInfo = ModelProbe.probe(model_path)
dest_path = self._config.models_path / info.base_model.value / info.model_type.value / model_path.name
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
dest_path.parent.mkdir(parents=True, exist_ok=True)
# if path already exists then we jigger the name to make it unique
counter: int = 1
@ -240,7 +294,7 @@ class ModelInstall(ModelInstallBase):
dest_path = dest_path.with_stem(dest_path.stem + f"_{counter:02d}")
counter += 1
self._register(
return self._register(
model_path.replace(dest_path),
info,
)
@ -253,6 +307,41 @@ class ModelInstall(ModelInstallBase):
rmtree(model.path)
self.unregister(id)
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase: # noqa D102
# choose a temporary directory inside the models directory
models_dir = self._config.models_path
queue = self._download_queue
self._async_installs[source] = None
def complete_installation(job: DownloadJobBase):
self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})")
if job.status == "completed":
id = self.install(job.destination)
info = self._store.get_model(id)
info.description = f"Downloaded model {info.name}"
info.source_url = str(job.source)
self._store.update_model(id, info)
self._async_installs[job.source] = id
jobs = queue.list_jobs()
if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
self._tmpdir = None
# note - this is probably not going to work. The tmpdir
# will be deleted before the job actually runs.
# Better to do the cleanup in the callback
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
return queue.create_download_job(
source=source,
destdir=self._tmpdir.name,
event_handlers=[complete_installation]
)
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102
self._download_queue.join()
id_map = self._async_installs
self._async_installs = dict()
return id_map
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
callback = self._scan_install if install else self._scan_register
search = ModelSearch(on_model_found=callback)