mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add simplified model manager install API to InvocationContext
This commit is contained in:
parent
c2e3c61f28
commit
34438ce1af
@ -1,9 +1,10 @@
|
|||||||
import threading
|
import threading
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
|
||||||
|
|
||||||
from PIL.Image import Image
|
from PIL.Image import Image
|
||||||
|
from pydantic.networks import AnyHttpUrl
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
@ -426,6 +427,101 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
model_format=format,
|
model_format=format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def install_model(
|
||||||
|
self,
|
||||||
|
source: str,
|
||||||
|
config: Optional[Dict[str, Any]] = None,
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
inplace: Optional[bool] = False,
|
||||||
|
timeout: Optional[int] = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Install and register a model in the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: String source; see below
|
||||||
|
config: Optional dict. Any fields in this dict
|
||||||
|
will override corresponding autoassigned probe fields in the
|
||||||
|
model's config record.
|
||||||
|
access_token: Optional access token for remote sources.
|
||||||
|
inplace: If true, installs a local model in place rather than copying
|
||||||
|
it into the models directory
|
||||||
|
timeout: How long to wait on install (in seconds). A value of 0 (default)
|
||||||
|
blocks indefinitely
|
||||||
|
|
||||||
|
The source can be:
|
||||||
|
1. A local file path in posix() format (`/foo/bar` or `C:\foo\bar`)
|
||||||
|
2. An http or https URL (`https://foo.bar/foo`)
|
||||||
|
3. A HuggingFace repo_id (`foo/bar`, `foo/bar:fp16`, `foo/bar:fp16:vae`)
|
||||||
|
|
||||||
|
We extend the HuggingFace repo_id syntax to include the variant and the
|
||||||
|
subfolder or path. The following are acceptable alternatives:
|
||||||
|
stabilityai/stable-diffusion-v4
|
||||||
|
stabilityai/stable-diffusion-v4:fp16
|
||||||
|
stabilityai/stable-diffusion-v4:fp16:vae
|
||||||
|
stabilityai/stable-diffusion-v4::/checkpoints/sd4.safetensors
|
||||||
|
stabilityai/stable-diffusion-v4:onnx:vae
|
||||||
|
|
||||||
|
Because a local file path can look like a huggingface repo_id, the logic
|
||||||
|
first checks whether the path exists on disk, and if not, it is treated as
|
||||||
|
a parseable huggingface repo.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Key to the newly installed model.
|
||||||
|
|
||||||
|
May Raise:
|
||||||
|
ValueError -- bad source
|
||||||
|
UnknownModelException -- remote model not found
|
||||||
|
InvalidModelException -- what was retrieved from remote is not a model
|
||||||
|
TimeoutError -- model could not be installed within timeout
|
||||||
|
Exception -- another error condition
|
||||||
|
"""
|
||||||
|
installer = self._services.model_manager.install
|
||||||
|
job = installer.heuristic_import(
|
||||||
|
source=source,
|
||||||
|
config=config,
|
||||||
|
access_token=access_token,
|
||||||
|
inplace=inplace,
|
||||||
|
)
|
||||||
|
installer.wait_for_job(job, timeout)
|
||||||
|
if job.errored:
|
||||||
|
raise Exception(job.error)
|
||||||
|
key: str = job.config_out.key
|
||||||
|
return key
|
||||||
|
|
||||||
|
def download_and_cache_model(
|
||||||
|
self,
|
||||||
|
source: Union[str, AnyHttpUrl],
|
||||||
|
access_token: Optional[str] = None,
|
||||||
|
timeout: Optional[int] = 0,
|
||||||
|
) -> Path:
|
||||||
|
"""Download the model file located at source to the models cache and return its Path.
|
||||||
|
|
||||||
|
This can be used to single-file install models and other resources of arbitrary types
|
||||||
|
which should not get registered with the database. If the model is already
|
||||||
|
installed, the cached path will be returned. Otherwise it will be downloaded.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
source: A URL or a string that can be converted in one. Repo_ids
|
||||||
|
do not work here.
|
||||||
|
access_token: Optional access token for restricted resources.
|
||||||
|
timeout: Wait up to the indicated number of seconds before timing
|
||||||
|
out long downloads.
|
||||||
|
|
||||||
|
Result:
|
||||||
|
Path of the downloaded model
|
||||||
|
|
||||||
|
May Raise:
|
||||||
|
HTTPError
|
||||||
|
TimeoutError
|
||||||
|
"""
|
||||||
|
installer = self._services.model_manager.install
|
||||||
|
path: Path = installer.download_and_cache(
|
||||||
|
source=source,
|
||||||
|
access_token=access_token,
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
class ConfigInterface(InvocationContextInterface):
|
class ConfigInterface(InvocationContextInterface):
|
||||||
def get(self) -> InvokeAIAppConfig:
|
def get(self) -> InvokeAIAppConfig:
|
||||||
|
Loading…
Reference in New Issue
Block a user