add simplified model manager install API to InvocationContext

This commit is contained in:
Lincoln Stein 2024-04-03 23:26:48 -04:00 committed by Lincoln Stein
parent c2e3c61f28
commit 34438ce1af

View File

@ -1,9 +1,10 @@
import threading
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from PIL.Image import Image
from pydantic.networks import AnyHttpUrl
from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES
@ -426,6 +427,101 @@ class ModelsInterface(InvocationContextInterface):
model_format=format,
)
def install_model(
self,
source: str,
config: Optional[Dict[str, Any]] = None,
access_token: Optional[str] = None,
inplace: Optional[bool] = False,
timeout: Optional[int] = 0,
) -> str:
"""Install and register a model in the database.
Args:
source: String source; see below
config: Optional dict. Any fields in this dict
will override corresponding autoassigned probe fields in the
model's config record.
access_token: Optional access token for remote sources.
inplace: If true, installs a local model in place rather than copying
it into the models directory
timeout: How long to wait on install (in seconds). A value of 0 (default)
blocks indefinitely
The source can be:
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
2. An http or https URL (`https://foo.bar/foo`)
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
We extend the HuggingFace repo_id syntax to include the variant and the
subfolder or path. The following are acceptable alternatives:
stabilityai/stable-diffusion-v4
stabilityai/stable-diffusion-v4:fp16
stabilityai/stable-diffusion-v4:fp16:vae
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
stabilityai/stable-diffusion-v4:onnx:vae
Because a local file path can look like a huggingface repo_id, the logic
first checks whether the path exists on disk, and if not, it is treated as
a parseable huggingface repo.
Returns:
Key to the newly installed model.
May Raise:
ValueError -- bad source
UnknownModelException -- remote model not found
InvalidModelException -- what was retrieved from remote is not a model
TimeoutError -- model could not be installed within timeout
Exception -- another error condition
"""
installer = self._services.model_manager.install
job = installer.heuristic_import(
source=source,
config=config,
access_token=access_token,
inplace=inplace,
)
installer.wait_for_job(job, timeout)
if job.errored:
raise Exception(job.error)
key: str = job.config_out.key
return key
def download_and_cache_model(
self,
source: Union[str, AnyHttpUrl],
access_token: Optional[str] = None,
timeout: Optional[int] = 0,
) -> Path:
"""Download the model file located at source to the models cache and return its Path.
This can be used to single-file install models and other resources of arbitrary types
which should not get registered with the database. If the model is already
installed, the cached path will be returned. Otherwise it will be downloaded.
Args:
source: A URL or a string that can be converted in one. Repo_ids
do not work here.
access_token: Optional access token for restricted resources.
timeout: Wait up to the indicated number of seconds before timing
out long downloads.
Result:
Path of the downloaded model
May Raise:
HTTPError
TimeoutError
"""
installer = self._services.model_manager.install
path: Path = installer.download_and_cache(
source=source,
access_token=access_token,
timeout=timeout,
)
return path
class ConfigInterface(InvocationContextInterface):
def get(self) -> InvokeAIAppConfig: