mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wire together download and install; now need to write install events
This commit is contained in:
@ -159,7 +159,6 @@ class DownloadQueue(DownloadQueueBase):
|
||||
self._lock.acquire()
|
||||
assert isinstance(self._jobs[job.id], DownloadJobBase)
|
||||
self._update_job_status(job, DownloadJobStatus.CANCELLED)
|
||||
# del self._jobs[job.id]
|
||||
except (AssertionError, KeyError) as excp:
|
||||
raise UnknownJobIDException("Unrecognized job") from excp
|
||||
finally:
|
||||
|
@ -7,10 +7,12 @@ Typical usage:
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.model_manager import ModelInstall
|
||||
from invokeai.backend.model_manager.storage import ModelConfigStoreSQL
|
||||
from invokeai.backend.model_manager.download import DownloadQueue
|
||||
|
||||
config = InvokeAIAppConfig.get_config()
|
||||
store = ModelConfigStoreSQL(config.db_path)
|
||||
installer = ModelInstall(store=store, config=config)
|
||||
download = DownloadQueue()
|
||||
installer = ModelInstall(store=store, config=config, download=download)
|
||||
|
||||
# register config, don't move path
|
||||
id: str = installer.register_model('/path/to/model')
|
||||
@ -18,6 +20,15 @@ Typical usage:
|
||||
# register config, and install model in `models`
|
||||
id: str = installer.install_model('/path/to/model')
|
||||
|
||||
# download some remote models and install them in the background
|
||||
installer.download('stabilityai/stable-diffusion-2-1')
|
||||
installer.download('https://civitai.com/api/download/models/154208')
|
||||
installer.download('runwayml/stable-diffusion-v1-5')
|
||||
|
||||
installed_ids = installer.wait_for_downloads()
|
||||
id1 = installed_ids['stabilityai/stable-diffusion-2-1']
|
||||
id2 = installed_ids['https://civitai.com/api/download/models/154208']
|
||||
|
||||
# unregister, don't delete
|
||||
installer.unregister(id)
|
||||
|
||||
@ -36,14 +47,17 @@ The following exceptions may be raised:
|
||||
DuplicateModelException
|
||||
UnknownModelTypeException
|
||||
"""
|
||||
import tempfile
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from shutil import rmtree
|
||||
from typing import Optional, List, Union
|
||||
from typing import Optional, List, Union, Dict
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
from .search import ModelSearch
|
||||
from .storage import ModelConfigStore, ModelConfigStoreYAML, DuplicateModelException
|
||||
from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase
|
||||
from .hash import FastModelHash
|
||||
from .probe import ModelProbe, ModelProbeInfo, InvalidModelException
|
||||
from .config import (
|
||||
@ -100,6 +114,40 @@ class ModelInstallBase(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase:
|
||||
"""
|
||||
Download and install the model located at remote site.
|
||||
|
||||
This will download the model located at `source`,
|
||||
probe it, and install it into the models directory.
|
||||
This call is executed asynchronously in a separate
|
||||
thread, and the returned object is a
|
||||
invokeai.backend.model_manager.download.DownloadJobBase
|
||||
object which can be interrogated to get the status of
|
||||
the download and install process. Call our `wait_for_downloads()`
|
||||
method to wait for all downloads to complete.
|
||||
|
||||
:param source: Either a URL or a HuggingFace repo_id.
|
||||
:returns queue: DownloadQueueBase object.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait_for_downloads(self) -> Dict[str, str]:
|
||||
"""
|
||||
Wait for all pending downloads to complete.
|
||||
|
||||
This will block until all pending downloads have
|
||||
completed, been cancelled, or errored out. It will
|
||||
block indefinitely if one or more jobs are in the
|
||||
paused state.
|
||||
|
||||
It will return a dict that maps the source model
|
||||
URL or repo_id to the ID of the installed model.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def unregister(self, id: str):
|
||||
"""
|
||||
@ -167,6 +215,9 @@ class ModelInstall(ModelInstallBase):
|
||||
_config: InvokeAIAppConfig
|
||||
_logger: InvokeAILogger
|
||||
_store: ModelConfigStore
|
||||
_download_queue: DownloadQueueBase
|
||||
_async_installs: Dict[str, str]
|
||||
_tmpdir: Optional[tempfile.TemporaryDirectory] # used for downloads
|
||||
|
||||
_legacy_configs = {
|
||||
BaseModelType.StableDiffusion1: {
|
||||
@ -196,12 +247,14 @@ class ModelInstall(ModelInstallBase):
|
||||
store: Optional[ModelConfigStore] = None,
|
||||
config: Optional[InvokeAIAppConfig] = None,
|
||||
logger: Optional[InvokeAILogger] = None,
|
||||
download: Optional[DownloadQueueBase] = None,
|
||||
): # noqa D107 - use base class docstrings
|
||||
self._config = config or InvokeAIAppConfig.get_config()
|
||||
self._logger = logger or InvokeAILogger.getLogger()
|
||||
if store is None:
|
||||
store = ModelConfigStoreYAML(config.model_conf_path)
|
||||
self._store = store
|
||||
self._store = store or ModelConfigStoreYAML(self._config.model_conf_path)
|
||||
self._download_queue = download or DownloadQueue()
|
||||
self._async_installs = dict()
|
||||
self._tmpdir = None
|
||||
|
||||
def register(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
@ -232,7 +285,8 @@ class ModelInstall(ModelInstallBase):
|
||||
def install(self, model_path: Union[Path, str]) -> str: # noqa D102
|
||||
model_path = Path(model_path)
|
||||
info: ModelProbeInfo = ModelProbe.probe(model_path)
|
||||
dest_path = self._config.models_path / info.base_model.value / info.model_type.value / model_path.name
|
||||
dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# if path already exists then we jigger the name to make it unique
|
||||
counter: int = 1
|
||||
@ -240,7 +294,7 @@ class ModelInstall(ModelInstallBase):
|
||||
dest_path = dest_path.with_stem(dest_path.stem + f"_{counter:02d}")
|
||||
counter += 1
|
||||
|
||||
self._register(
|
||||
return self._register(
|
||||
model_path.replace(dest_path),
|
||||
info,
|
||||
)
|
||||
@ -253,6 +307,41 @@ class ModelInstall(ModelInstallBase):
|
||||
rmtree(model.path)
|
||||
self.unregister(id)
|
||||
|
||||
def download(self, source: Union[str, AnyHttpUrl]) -> DownloadJobBase: # noqa D102
|
||||
# choose a temporary directory inside the models directory
|
||||
models_dir = self._config.models_path
|
||||
queue = self._download_queue
|
||||
self._async_installs[source] = None
|
||||
|
||||
def complete_installation(job: DownloadJobBase):
|
||||
self._logger.info(f"{job.source}: {job.status} filename={job.destination}({job.bytes}/{job.total_bytes})")
|
||||
if job.status == "completed":
|
||||
id = self.install(job.destination)
|
||||
info = self._store.get_model(id)
|
||||
info.description = f"Downloaded model {info.name}"
|
||||
info.source_url = str(job.source)
|
||||
self._store.update_model(id, info)
|
||||
self._async_installs[job.source] = id
|
||||
jobs = queue.list_jobs()
|
||||
if len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
|
||||
self._tmpdir = None
|
||||
|
||||
# note - this is probably not going to work. The tmpdir
|
||||
# will be deleted before the job actually runs.
|
||||
# Better to do the cleanup in the callback
|
||||
self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
|
||||
return queue.create_download_job(
|
||||
source=source,
|
||||
destdir=self._tmpdir.name,
|
||||
event_handlers=[complete_installation]
|
||||
)
|
||||
|
||||
def wait_for_downloads(self) -> Dict[str, str]: # noqa D102
|
||||
self._download_queue.join()
|
||||
id_map = self._async_installs
|
||||
self._async_installs = dict()
|
||||
return id_map
|
||||
|
||||
def scan_directory(self, scan_dir: Path, install: bool = False) -> List[str]: # noqa D102
|
||||
callback = self._scan_install if install else self._scan_register
|
||||
search = ModelSearch(on_model_found=callback)
|
||||
|
Reference in New Issue
Block a user