mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add simplified model manager install API to InvocationContext
This commit is contained in:
parent
c2e3c61f28
commit
34438ce1af
@ -1,9 +1,10 @@
|
||||
import threading
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||
|
||||
from PIL.Image import Image
|
||||
from pydantic.networks import AnyHttpUrl
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
@ -426,6 +427,101 @@ class ModelsInterface(InvocationContextInterface):
|
||||
model_format=format,
|
||||
)
|
||||
|
||||
def install_model(
|
||||
self,
|
||||
source: str,
|
||||
config: Optional[Dict[str, Any]] = None,
|
||||
access_token: Optional[str] = None,
|
||||
inplace: Optional[bool] = False,
|
||||
timeout: Optional[int] = 0,
|
||||
) -> str:
|
||||
"""Install and register a model in the database.
|
||||
|
||||
Args:
|
||||
source: String source; see below
|
||||
config: Optional dict. Any fields in this dict
|
||||
will override corresponding autoassigned probe fields in the
|
||||
model's config record.
|
||||
access_token: Optional access token for remote sources.
|
||||
inplace: If true, installs a local model in place rather than copying
|
||||
it into the models directory
|
||||
timeout: How long to wait on install (in seconds). A value of 0 (default)
|
||||
blocks indefinitely
|
||||
|
||||
The source can be:
|
||||
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
|
||||
2. An http or https URL (`https://foo.bar/foo`)
|
||||
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
|
||||
|
||||
We extend the HuggingFace repo_id syntax to include the variant and the
|
||||
subfolder or path. The following are acceptable alternatives:
|
||||
stabilityai/stable-diffusion-v4
|
||||
stabilityai/stable-diffusion-v4:fp16
|
||||
stabilityai/stable-diffusion-v4:fp16:vae
|
||||
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||
stabilityai/stable-diffusion-v4:onnx:vae
|
||||
|
||||
Because a local file path can look like a huggingface repo_id, the logic
|
||||
first checks whether the path exists on disk, and if not, it is treated as
|
||||
a parseable huggingface repo.
|
||||
|
||||
Returns:
|
||||
Key to the newly installed model.
|
||||
|
||||
May Raise:
|
||||
ValueError -- bad source
|
||||
UnknownModelException -- remote model not found
|
||||
InvalidModelException -- what was retrieved from remote is not a model
|
||||
TimeoutError -- model could not be installed within timeout
|
||||
Exception -- another error condition
|
||||
"""
|
||||
installer = self._services.model_manager.install
|
||||
job = installer.heuristic_import(
|
||||
source=source,
|
||||
config=config,
|
||||
access_token=access_token,
|
||||
inplace=inplace,
|
||||
)
|
||||
installer.wait_for_job(job, timeout)
|
||||
if job.errored:
|
||||
raise Exception(job.error)
|
||||
key: str = job.config_out.key
|
||||
return key
|
||||
|
||||
def download_and_cache_model(
|
||||
self,
|
||||
source: Union[str, AnyHttpUrl],
|
||||
access_token: Optional[str] = None,
|
||||
timeout: Optional[int] = 0,
|
||||
) -> Path:
|
||||
"""Download the model file located at source to the models cache and return its Path.
|
||||
|
||||
This can be used to single-file install models and other resources of arbitrary types
|
||||
which should not get registered with the database. If the model is already
|
||||
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||
|
||||
Args:
|
||||
source: A URL or a string that can be converted in one. Repo_ids
|
||||
do not work here.
|
||||
access_token: Optional access token for restricted resources.
|
||||
timeout: Wait up to the indicated number of seconds before timing
|
||||
out long downloads.
|
||||
|
||||
Result:
|
||||
Path of the downloaded model
|
||||
|
||||
May Raise:
|
||||
HTTPError
|
||||
TimeoutError
|
||||
"""
|
||||
installer = self._services.model_manager.install
|
||||
path: Path = installer.download_and_cache(
|
||||
source=source,
|
||||
access_token=access_token,
|
||||
timeout=timeout,
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
class ConfigInterface(InvocationContextInterface):
|
||||
def get(self) -> InvokeAIAppConfig:
|
||||
|
Loading…
Reference in New Issue
Block a user