From 34438ce1af2da2a0c879a83e6fe2780a9d49911c Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Wed, 3 Apr 2024 23:26:48 -0400 Subject: [PATCH] add simplified model manager install API to InvocationContext --- .../app/services/shared/invocation_context.py | 98 ++++++++++++++++++- 1 file changed, 97 insertions(+), 1 deletion(-) diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9994d663e5..176303b055 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -1,9 +1,10 @@ import threading from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union from PIL.Image import Image +from pydantic.networks import AnyHttpUrl from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES @@ -426,6 +427,101 @@ class ModelsInterface(InvocationContextInterface): model_format=format, ) + def install_model( + self, + source: str, + config: Optional[Dict[str, Any]] = None, + access_token: Optional[str] = None, + inplace: Optional[bool] = False, + timeout: Optional[int] = 0, + ) -> str: + """Install and register a model in the database. + + Args: + source: String source; see below + config: Optional dict. Any fields in this dict + will override corresponding autoassigned probe fields in the + model's config record. + access_token: Optional access token for remote sources. + inplace: If true, installs a local model in place rather than copying + it into the models directory + timeout: How long to wait on install (in seconds). A value of 0 (default) + blocks indefinitely + + The source can be: + 1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`) + 2. An http or https URL (`https://foo.bar/foo`) + 3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`) + + We extend the HuggingFace repo_id syntax to include the variant and the + subfolder or path. The following are acceptable alternatives: + stabilityai/stable-diffusion-v4 + stabilityai/stable-diffusion-v4:fp16 + stabilityai/stable-diffusion-v4:fp16:vae + stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors + stabilityai/stable-diffusion-v4:onnx:vae + + Because a local file path can look like a huggingface repo_id, the logic + first checks whether the path exists on disk, and if not, it is treated as + a parseable huggingface repo. + + Returns: + Key to the newly installed model. + + May Raise: + ValueError -- bad source + UnknownModelException -- remote model not found + InvalidModelException -- what was retrieved from remote is not a model + TimeoutError -- model could not be installed within timeout + Exception -- another error condition + """ + installer = self._services.model_manager.install + job = installer.heuristic_import( + source=source, + config=config, + access_token=access_token, + inplace=inplace, + ) + installer.wait_for_job(job, timeout) + if job.errored: + raise Exception(job.error) + key: str = job.config_out.key + return key + + def download_and_cache_model( + self, + source: Union[str, AnyHttpUrl], + access_token: Optional[str] = None, + timeout: Optional[int] = 0, + ) -> Path: + """Download the model file located at source to the models cache and return its Path. + + This can be used to single-file install models and other resources of arbitrary types + which should not get registered with the database. If the model is already + installed, the cached path will be returned. Otherwise it will be downloaded. + + Args: + source: A URL or a string that can be converted in one. Repo_ids + do not work here. + access_token: Optional access token for restricted resources. + timeout: Wait up to the indicated number of seconds before timing + out long downloads. + + Result: + Path of the downloaded model + + May Raise: + HTTPError + TimeoutError + """ + installer = self._services.model_manager.install + path: Path = installer.download_and_cache( + source=source, + access_token=access_token, + timeout=timeout, + ) + return path + class ConfigInterface(InvocationContextInterface): def get(self) -> InvokeAIAppConfig: