diff --git a/invokeai/backend/model_manager/download/base.py b/invokeai/backend/model_manager/download/base.py
index 56884305b5..1d864c289b 100644
--- a/invokeai/backend/model_manager/download/base.py
+++ b/invokeai/backend/model_manager/download/base.py
@@ -46,12 +46,9 @@ class DownloadJobBase(BaseModel):
     """Class to monitor and control a model download request."""
 
     priority: int = Field(default=10, description="Queue priority; lower values are higher priority")
-    id: int = Field(description="Numeric ID of this job")
+    id: int = Field(description="Numeric ID of this job", default=-1)  # default id is a placeholder
     source: str = Field(description="URL or repo_id to download")
     destination: Path = Field(description="Destination of URL on local disk")
-    model_key: Optional[str] = Field(
-        description="After model installation, this field will hold its primary key", default=None
-    )
     metadata: Optional[ModelSourceMetadata] = Field(description="Model metadata (source-specific)", default=None)
     access_token: Optional[str] = Field(description="access token needed to access this resource")
     status: DownloadJobStatus = Field(default=DownloadJobStatus.IDLE, description="Status of the download")
@@ -114,7 +111,7 @@ class DownloadQueueBase(ABC):
         event_handlers: Optional[List[DownloadEventHandler]] = None,
     ) -> DownloadJobBase:
         """
-        Create a download job.
+        Create and submit a download job.
 
         :param source: Source of the download - URL, repo_id or Path
         :param destdir: Directory to download into.
@@ -123,7 +120,22 @@ class DownloadQueueBase(ABC):
         :param start: Immediately start job [True]
         :param variant: Variant to download, such as "fp16" (repo_ids only).
         :param event_handlers: Optional callables that will be called whenever job status changes.
-        :returns job id: The numeric ID of the DownloadJobBase object for this task.
+        :returns the job: job.id will be a non-negative value after execution
+        """
+        pass
+
+    def submit_download_job(
+        self,
+        job: DownloadJobBase,
+        start: bool = True,
+    ):
+        """
+        Submit a download job.
+
+        :param job: A DownloadJobBase
+        :param start: Immediately start job [True]
+
+        After execution, `job.id` will be set to a non-negative value.
""" pass diff --git a/invokeai/backend/model_manager/download/queue.py b/invokeai/backend/model_manager/download/queue.py index 7d6e8e551d..35bec195c2 100644 --- a/invokeai/backend/model_manager/download/queue.py +++ b/invokeai/backend/model_manager/download/queue.py @@ -78,7 +78,7 @@ class DownloadQueue(DownloadQueueBase): _queue: PriorityQueue _lock: threading.Lock _logger: InvokeAILogger - _event_handlers: Optional[List[DownloadEventHandler]] + _event_handlers: List[DownloadEventHandler] = Field(default_factory=list) _next_job_id: int = 0 _sequence: int = 0 # This is for debugging and used to tag jobs in dequeueing order _requests: requests.sessions.Session @@ -90,7 +90,7 @@ class DownloadQueue(DownloadQueueBase): def __init__( self, max_parallel_dl: int = 5, - event_handlers: Optional[List[DownloadEventHandler]] = None, + event_handlers: List[DownloadEventHandler] = [], requests_session: Optional[requests.sessions.Session] = None, config: Optional[InvokeAIAppConfig] = None, ): @@ -139,19 +139,30 @@ class DownloadQueue(DownloadQueueBase): else: raise NotImplementedError(f"Don't know what to do with this type of source: {source}") + job = cls( + source=source, + destination=Path(destdir) / (filename or "."), + access_token=access_token, + event_handlers=(event_handlers or self._event_handlers), + **kwargs, + ) + + return self.submit_download_job(job, start) + + def submit_download_job( + self, + job: DownloadJobBase, + start: bool = True, + ): + """Submit a job.""" + # add the queue's handlers + for handler in self._event_handlers: + job.add_event_handler(handler) try: self._lock.acquire() - id = self._next_job_id - self._jobs[id] = cls( - id=id, - source=source, - destination=Path(destdir) / (filename or "."), - access_token=access_token, - event_handlers=(event_handlers or self._event_handlers), - **kwargs, - ) + job.id = self._next_job_id + self._jobs[job.id] = job self._next_job_id += 1 - job = self._jobs[id] finally: self._lock.release() if start: diff --git a/invokeai/backend/model_manager/install.py b/invokeai/backend/model_manager/install.py index 34ed60dcd0..aff8ceb084 100644 --- a/invokeai/backend/model_manager/install.py +++ b/invokeai/backend/model_manager/install.py @@ -20,7 +20,7 @@ Typical usage: # register config, and install model in `models` id: str = installer.install_path('/path/to/model') -1 # download some remote models and install them in the background + # download some remote models and install them in the background installer.install('stabilityai/stable-diffusion-2-1') installer.install('https://civitai.com/api/download/models/154208') installer.install('runwayml/stable-diffusion-v1-5') @@ -48,6 +48,7 @@ The following exceptions may be raised: DuplicateModelException UnknownModelTypeException """ +import re import tempfile from abc import ABC, abstractmethod from pathlib import Path @@ -60,6 +61,7 @@ from invokeai.backend.util.logging import InvokeAILogger from .search import ModelSearch from .storage import ModelConfigStore, DuplicateModelException, get_config_store from .download import DownloadQueueBase, DownloadQueue, DownloadJobBase, ModelSourceMetadata +from .download.queue import DownloadJobURL, DownloadJobRepoID, DownloadJobPath from .hash import FastModelHash from .probe import ModelProbe, ModelProbeInfo, InvalidModelException from .config import ( @@ -71,6 +73,30 @@ from .config import ( ) +class ModelInstallJob(DownloadJobBase): + """This is a version of DownloadJobBase that has an additional slot for the model key and probe info.""" + + model_key: 
Optional[str] = Field( + description="After model installation, this field will hold its primary key", default=None + ) + probe_info: Optional[ModelProbeInfo] = Field( + description="If provided, information here will be used instead of probing the model.", + default=None, + ) + + +class ModelInstallURLJob(DownloadJobURL, ModelInstallJob): + """Job for installing URLs.""" + + +class ModelInstallRepoIDJob(DownloadJobRepoID, ModelInstallJob): + """Job for installing repo ids.""" + + +class ModelInstallPathJob(DownloadJobPath, ModelInstallJob): + """Job for installing local paths.""" + + class ModelInstallBase(ABC): """Abstract base class for InvokeAI model installation""" @@ -103,17 +129,18 @@ class ModelInstallBase(ABC): pass @abstractmethod - def register_path(self, model_path: Union[Path, str]) -> str: + def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: """ Probe and register the model at model_path. :param model_path: Filesystem Path to the model. + :param info: Optional ModelProbeInfo object. If not provided, model will be probed. :returns id: The string ID of the registered model. """ pass @abstractmethod - def install_path(self, model_path: Union[Path, str]) -> str: + def install_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str: """ Probe, register and install the model in the models directory. @@ -121,13 +148,18 @@ class ModelInstallBase(ABC): the models directory handled by InvokeAI. :param model_path: Filesystem Path to the model. + :param info: Optional ModelProbeInfo object. If not provided, model will be probed. :returns id: The string ID of the installed model. """ pass @abstractmethod def install( - self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None + self, + source: Union[str, Path, AnyHttpUrl], + inplace: bool = True, + variant: Optional[str] = None, + info: Optional[ModelProbeInfo] = None, ) -> DownloadJobBase: """ Download and install the indicated model. @@ -146,6 +178,7 @@ class ModelInstallBase(ABC): the models directory, but registered in place (the default). :param variant: For HuggingFace models, this optional parameter specifies which variant to download (e.g. 'fp16') + :param info: Optional ModelProbeInfo object. If not provided, model will be probed. :returns DownloadQueueBase object. 
 
         The `inplace` flag does not affect the behavior of downloaded
@@ -283,9 +316,9 @@ class ModelInstall(ModelInstallBase):
         """Return the queue."""
         return self._download_queue
 
-    def register_path(self, model_path: Union[Path, str]) -> str:  # noqa D102
+    def register_path(self, model_path: Union[Path, str], info: Optional[ModelProbeInfo] = None) -> str:  # noqa D102
         model_path = Path(model_path)
-        info: ModelProbeInfo = ModelProbe.probe(model_path)
+        info: ModelProbeInfo = info or ModelProbe.probe(model_path)
         return self._register(model_path, info)
 
     def _register(self, model_path: Path, info: ModelProbeInfo) -> str:
@@ -317,9 +350,13 @@ class ModelInstall(ModelInstallBase):
         self._store.add_model(key, registration_data)
         return key
 
-    def install_path(self, model_path: Union[Path, str]) -> str:  # noqa D102
+    def install_path(
+        self,
+        model_path: Union[Path, str],
+        info: Optional[ModelProbeInfo] = None,
+    ) -> str:  # noqa D102
         model_path = Path(model_path)
-        info: ModelProbeInfo = ModelProbe.probe(model_path)
+        info: ModelProbeInfo = info or ModelProbe.probe(model_path)
 
         dest_path = self._config.models_path / info.base_type.value / info.model_type.value / model_path.name
         dest_path.parent.mkdir(parents=True, exist_ok=True)
@@ -343,66 +380,94 @@ class ModelInstall(ModelInstallBase):
         self.unregister(key)
 
     def install(
-        self, source: Union[str, Path, AnyHttpUrl], inplace: bool = True, variant: Optional[str] = None
+        self,
+        source: Union[str, Path, AnyHttpUrl],
+        info: Optional[ModelProbeInfo] = None,
+        inplace: bool = True,
+        variant: Optional[str] = None,
+        access_token: Optional[str] = None,
     ) -> DownloadJobBase:  # noqa D102
-        # choose a temporary directory inside the models directory
-        models_dir = self._config.models_path
         queue = self._download_queue
 
-        def complete_installation(job: DownloadJobBase):
-            if job.status == "completed":
-                self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
-                model_id = self.install_path(job.destination)
-                info = self._store.get_model(model_id)
-                info.source = str(job.source)
-                metadata: ModelSourceMetadata = job.metadata
-                info.description = metadata.description or f"Imported model {info.name}"
-                info.author = metadata.author
-                info.tags = metadata.tags
-                info.license = metadata.license
-                info.thumbnail_url = metadata.thumbnail_url
-                self._store.update_model(model_id, info)
-                self._async_installs[job.source] = model_id
-            elif job.status == "error":
-                self._logger.warning(f"{job.source}: Model installation error: {job.error}")
-            elif job.status == "cancelled":
-                self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
-            jobs = queue.list_jobs()
-            if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
-                self._tmpdir.cleanup()
-                self._tmpdir = None
+        job = self._make_download_job(source, variant, access_token)
+        handler = (
+            self._complete_registration_handler
+            if inplace and Path(source).exists()
+            else self._complete_installation_handler
+        )
+        job.add_event_handler(handler)
+        job.probe_info = info
 
-        def complete_registration(job: DownloadJobBase):
-            if job.status == "completed":
-                self._logger.info(f"{job.source}: Installing in place.")
-                model_id = self.register_path(job.destination)
-                info = self._store.get_model(model_id)
-                info.source = str(job.source)
-                info.description = f"Imported model {info.name}"
-                self._store.update_model(model_id, info)
-                self._async_installs[job.source] = model_id
-                job.model_key = model_id
-            elif job.status == "error":
-                self._logger.warning(f"{job.source}: Model installation error: {job.error}")
-            elif job.status == "cancelled":
-                self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
+        self._async_installs[source] = None
+        queue.submit_download_job(job, True)
+        return job
 
+    def _complete_installation_handler(self, job: DownloadJobBase):
+        if job.status == "completed":
+            self._logger.info(f"{job.source}: Download finished with status {job.status}. Installing.")
+            model_id = self.install_path(job.destination, job.probe_info)
+            info = self._store.get_model(model_id)
+            info.source = str(job.source)
+            metadata: ModelSourceMetadata = job.metadata
+            info.description = metadata.description or f"Imported model {info.name}"
+            info.author = metadata.author
+            info.tags = metadata.tags
+            info.license = metadata.license
+            info.thumbnail_url = metadata.thumbnail_url
+            self._store.update_model(model_id, info)
+            self._async_installs[job.source] = model_id
+            job.model_key = model_id
+        elif job.status == "error":
+            self._logger.warning(f"{job.source}: Model installation error: {job.error}")
+        elif job.status == "cancelled":
+            self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
+        jobs = self._download_queue.list_jobs()
+        if self._tmpdir and len(jobs) <= 1 and job.status in ["completed", "error", "cancelled"]:
+            self._tmpdir.cleanup()
+            self._tmpdir = None
+
+    def _complete_registration_handler(self, job: DownloadJobBase):
+        if job.status == "completed":
+            self._logger.info(f"{job.source}: Installing in place.")
+            model_id = self.register_path(job.destination, job.probe_info)
+            info = self._store.get_model(model_id)
+            info.source = str(job.source)
+            info.description = f"Imported model {info.name}"
+            self._store.update_model(model_id, info)
+            self._async_installs[job.source] = model_id
+            job.model_key = model_id
+        elif job.status == "error":
+            self._logger.warning(f"{job.source}: Model installation error: {job.error}")
+        elif job.status == "cancelled":
+            self._logger.warning(f"{job.source}: Model installation cancelled at caller's request.")
+
+    def _make_download_job(
+        self,
+        source: Union[str, Path, AnyHttpUrl],
+        variant: Optional[str] = None,
+        access_token: Optional[str] = None,
+    ) -> DownloadJobBase:
         # In the event that we are being asked to install a path that is already on disk,
         # we simply probe and register/install it. The job does not actually do anything, but we
         # create one anyway in order to have similar behavior for local files, URLs and repo_ids.
         if Path(source).exists():  # a path that is already on disk
             source = Path(source)
             destdir = source
-            job = queue.create_download_job(source=source, destdir=destdir, start=False, variant=variant)
-            job.add_event_handler(complete_registration if inplace else complete_installation)
-        else:
-            self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
-            job = queue.create_download_job(source=source, destdir=self._tmpdir.name, start=False, variant=variant)
-            job.add_event_handler(complete_installation)
+            return ModelInstallPathJob(source=source, destination=Path(destdir))
 
-        self._async_installs[source] = None
-        queue.start_job(job)
-        return job
+        # choose a temporary directory inside the models directory
+        models_dir = self._config.models_path
+        self._tmpdir = self._tmpdir or tempfile.TemporaryDirectory(dir=models_dir)
+
+        if re.match(r"^[\w-]+/[\w-]+$", str(source)):
+            cls = ModelInstallRepoIDJob
+            kwargs = dict(variant=variant)
+        elif re.match(r"^https?://", str(source)):
+            cls = ModelInstallURLJob
+            kwargs = {}
+        else:
+            raise NotImplementedError(f"Don't know what to do with this type of source: {source}")
+        return cls(source=source, destination=Path(self._tmpdir.name), access_token=access_token, **kwargs)
 
     def wait_for_installs(self) -> Dict[str, str]:  # noqa D102
         self._download_queue.join()
diff --git a/invokeai/backend/model_manager/models/base.py b/invokeai/backend/model_manager/models/base.py
index 12312969b1..01fdc93718 100644
--- a/invokeai/backend/model_manager/models/base.py
+++ b/invokeai/backend/model_manager/models/base.py
@@ -20,7 +20,6 @@ from onnxruntime import (
     SessionOptions,
     get_available_providers,
 )
-from pydantic import BaseModel
 from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
 from ..config import (  # noqa F401
     BaseModelType,
diff --git a/invokeai/backend/model_manager/models/stable_diffusion.py b/invokeai/backend/model_manager/models/stable_diffusion.py
index 622a707d07..429db2fb87 100644
--- a/invokeai/backend/model_manager/models/stable_diffusion.py
+++ b/invokeai/backend/model_manager/models/stable_diffusion.py
@@ -14,6 +14,7 @@ from .base import (
     read_checkpoint_meta,
     classproperty,
     InvalidModelException,
+    ModelNotFoundException,
 )
 from ..config import SilenceWarnings
 from .sdxl import StableDiffusionXLModel
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index f9a96c5210..4d45b27621 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -37,11 +37,11 @@ class ModelProbeInfo(object):
     model_type: ModelType
     base_type: BaseModelType
-    variant_type: ModelVariantType
-    prediction_type: SchedulerPredictionType
-    upcast_attention: bool
     format: ModelFormat
-    image_size: int
+    variant_type: ModelVariantType = "normal"
+    prediction_type: SchedulerPredictionType = "v_prediction"
+    upcast_attention: bool = False
+    image_size: int = None
 
 
 class ModelProbeBase(ABC):
diff --git a/invokeai/backend/util/__init__.py b/invokeai/backend/util/__init__.py
index f7b5d36d77..d87fa2fdf1 100644
--- a/invokeai/backend/util/__init__.py
+++ b/invokeai/backend/util/__init__.py
@@ -18,4 +18,4 @@ from .util import (  # noqa: F401
     Chdir,
 )
 from .attention import auto_detect_slice_size  # noqa: F401
-from .logging import InvokeAILogger
+from .logging import InvokeAILogger  # noqa: F401
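
A minimal usage sketch (not part of the patch) of how the reworked install API above might be exercised, following the "Typical usage" section of the install.py module docstring. The way the config and the ModelInstall instance are constructed here is an assumption, not something this diff defines:

    from invokeai.app.services.config import InvokeAIAppConfig
    from invokeai.backend.model_manager.install import ModelInstall

    config = InvokeAIAppConfig.get_config()    # assumed way of obtaining the app config
    installer = ModelInstall(config=config)    # constructor arguments assumed from the module docstring

    # a local path is probed, registered and installed synchronously
    installer.install_path('/path/to/model')   # illustrative path

    # remote sources return a ModelInstall*Job immediately; submit_download_job() assigns job.id
    # and the matching _complete_*_handler fires when the download finishes
    job = installer.install('runwayml/stable-diffusion-v1-5', variant='fp16')
    installed = installer.wait_for_installs()  # blocks until done; maps each source to its model key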