2023-02-16 05:34:15 +00:00
|
|
|
"""
|
2023-02-15 06:07:39 +00:00
|
|
|
Utility (backend) functions used by model_install.py
|
2023-02-16 05:34:15 +00:00
|
|
|
"""
|
2023-02-15 06:07:39 +00:00
|
|
|
import os
|
|
|
|
import shutil
|
|
|
|
import warnings
|
2023-07-28 13:46:44 +00:00
|
|
|
from dataclasses import dataclass, field
|
2023-02-15 06:07:39 +00:00
|
|
|
from pathlib import Path
|
2023-06-17 02:54:36 +00:00
|
|
|
from tempfile import TemporaryDirectory
|
2023-07-30 14:25:12 +00:00
|
|
|
from typing import Optional, List, Dict, Callable, Union, Set
|
2023-02-15 06:07:39 +00:00
|
|
|
|
|
|
|
import requests
|
2023-07-19 01:10:33 +00:00
|
|
|
from diffusers import DiffusionPipeline
|
2023-07-01 18:32:58 +00:00
|
|
|
from diffusers import logging as dlogging
|
2023-08-02 02:06:27 +00:00
|
|
|
import torch
|
2023-06-17 02:54:36 +00:00
|
|
|
from huggingface_hub import hf_hub_url, HfFolder, HfApi
|
2023-02-15 06:07:39 +00:00
|
|
|
from omegaconf import OmegaConf
|
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
|
|
import invokeai.configs as configs
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2023-05-26 00:41:26 +00:00
|
|
|
from invokeai.app.services.config import InvokeAIAppConfig
|
2023-07-03 23:32:54 +00:00
|
|
|
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
|
2023-06-17 02:54:36 +00:00
|
|
|
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
|
|
|
|
from invokeai.backend.util import download_with_resume
|
2023-08-02 02:06:27 +00:00
|
|
|
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
|
2023-06-01 04:31:46 +00:00
|
|
|
from ..util.logging import InvokeAILogger
|
2023-05-04 04:43:51 +00:00
|
|
|
|
2023-02-15 06:07:39 +00:00
|
|
|
# NOTE(review): this silences *all* Python warnings process-wide as a side effect of
# importing this module, not only warnings raised by the model-install code.
warnings.filterwarnings("ignore")

# --------------------------globals-----------------------
# Application configuration, fetched once at import time.
# NOTE(review): the code visible in this module does not read this global; ModelInstall
# uses the config object passed to its constructor instead.
config = InvokeAIAppConfig.get_config()
# Shared "InvokeAI" logger used for progress and warning messages in this module.
logger = InvokeAILogger.getLogger(name="InvokeAI")

# the initial "configs" dir is now bundled in the `invokeai.configs` package
# INITIAL_MODELS.yaml is the starter-model catalogue loaded by ModelInstall.__init__.
Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml"
|
|
|
|
|
2023-02-16 05:34:15 +00:00
|
|
|
Config_preamble = """
|
|
|
|
# This file describes the alternative machine learning models
|
2023-02-15 06:07:39 +00:00
|
|
|
# available to InvokeAI script.
|
|
|
|
#
|
|
|
|
# To add a new model, follow the examples below. Each
|
|
|
|
# model requires a model config file, a weights file,
|
|
|
|
# and the width and height of the images it
|
|
|
|
# was trained on.
|
|
|
|
"""
|
|
|
|
|
2023-06-17 02:54:36 +00:00
|
|
|
# Legacy inference .yaml config file names, used by ModelInstall._make_attributes for
# checkpoint-format models that have no sibling .yaml file. The chosen name is joined onto
# config.legacy_conf_dir.
#
# Lookup shape: base model type -> variant -> file name.
# StableDiffusion2 adds a third level keyed by scheduler prediction type
# (Epsilon vs. VPrediction).
LEGACY_CONFIGS = {
    BaseModelType.StableDiffusion1: {
        ModelVariantType.Normal: "v1-inference.yaml",
        ModelVariantType.Inpaint: "v1-inpainting-inference.yaml",
    },
    BaseModelType.StableDiffusion2: {
        ModelVariantType.Normal: {
            SchedulerPredictionType.Epsilon: "v2-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inference-v.yaml",
        },
        ModelVariantType.Inpaint: {
            SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml",
            SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml",
        },
    },
    BaseModelType.StableDiffusionXL: {
        ModelVariantType.Normal: "sd_xl_base.yaml",
    },
    BaseModelType.StableDiffusionXLRefiner: {
        ModelVariantType.Normal: "sd_xl_refiner.yaml",
    },
}
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-01 04:31:46 +00:00
|
|
|
@dataclass
class ModelInstallList:
    """Class for listing models to be installed/removed"""

    # Models to install.
    # NOTE(review): presumably the same contents as InstallSelections.install_models
    # (paths, repo ids or URLs) — confirm against callers.
    install_models: List[str] = field(default_factory=list)
    # Models to remove (presumably model keys, as in InstallSelections — confirm).
    remove_models: List[str] = field(default_factory=list)
|
2023-03-03 06:02:00 +00:00
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-03 03:19:14 +00:00
|
|
|
@dataclass
class InstallSelections:
    """The set of install/remove requests consumed by ModelInstall.install()."""

    # Each entry is handed to ModelInstall.heuristic_import(): a local path,
    # a HuggingFace repo id ("owner/name"), or a URL.
    install_models: List[str] = field(default_factory=list)
    # Each entry is a model key, split by ModelManager.parse_key() into
    # (name, base, type) and then deleted.
    remove_models: List[str] = field(default_factory=list)
|
|
|
|
|
2023-06-16 03:32:33 +00:00
|
|
|
|
|
|
|
@dataclass
class ModelLoadInfo:
    """
    One model as presented by ModelInstall.all_models(), built either from an
    INITIAL_MODELS.yaml entry or from a model registered in the ModelManager.
    """

    name: str
    model_type: ModelType
    base_type: BaseModelType
    # Location as recorded in the source YAML, if any.
    path: Optional[Path] = None
    # HuggingFace repo id, if the catalogue entry names one.
    repo_id: Optional[str] = None
    description: str = ""
    # Set to True by all_models() when the model is registered in the ModelManager.
    installed: bool = False
    # Flags copied from INITIAL_MODELS.yaml entries.
    recommended: bool = False
    default: bool = False
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-06-16 03:32:33 +00:00
|
|
|
class ModelInstall(object):
    """
    Backend for installing, removing and listing models.

    Combines the starter-model catalogue in INITIAL_MODELS.yaml with the models
    registered in the ModelManager, and imports new models from local paths,
    HuggingFace repo ids, or URLs.
    """

    def __init__(
        self,
        config: InvokeAIAppConfig,
        prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None,
        model_manager: Optional[ModelManager] = None,
        access_token: Optional[str] = None,
    ):
        """
        :param config: application configuration (models path, root path, legacy config dir, precision)
        :param prediction_type_helper: optional callback, forwarded to ModelProbe.heuristic_probe,
            that returns the scheduler prediction type for a model path
        :param model_manager: ModelManager to operate on; created from config.model_conf_path if omitted
        :param access_token: HuggingFace token; falls back to the token stored by HfFolder
        """
        self.config = config
        self.mgr = model_manager or ModelManager(config.model_conf_path)
        self.datasets = OmegaConf.load(Dataset_path)
        self.prediction_helper = prediction_type_helper
        self.access_token = access_token or HfFolder.get_token()
        self.reverse_paths = self._reverse_paths(self.datasets)
        # Path/repo id/URL currently being imported. heuristic_import() sets it and
        # _make_attributes() reads it to find a curated description. It is initialised
        # here so _make_attributes() is safe when reached through another entry point.
        self.current_id: Optional[Union[str, Path]] = None

    def all_models(self) -> Dict[str, ModelLoadInfo]:
        """
        Return dict of model_key=>ModelLoadInfo objects.

        This method consolidates and simplifies the entries in both
        models.yaml and INITIAL_MODELS.yaml so that they can
        be treated uniformly. It also sorts the models alphabetically
        by their name, to improve the display somewhat.
        """
        model_dict: Dict[str, ModelLoadInfo] = dict()

        # First populate with the entries in INITIAL_MODELS.yaml.
        # Build a fresh mapping per entry rather than writing into self.datasets, so
        # listing models does not mutate the loaded catalogue.
        for key, value in self.datasets.items():
            name, base, model_type = ModelManager.parse_key(key)
            model_dict[key] = ModelLoadInfo(
                **{**value, "name": name, "base_type": base, "model_type": model_type}
            )

        # Supplement with entries in models.yaml.
        # To hide auto-imported models, filter this list with self._is_autoloaded(x).
        installed_models = list(self.mgr.list_models())

        for md in installed_models:
            base = md["base_model"]
            model_type = md["model_type"]
            name = md["model_name"]
            key = ModelManager.create_key(name, base, model_type)
            if key in model_dict:
                model_dict[key].installed = True
            else:
                model_dict[key] = ModelLoadInfo(
                    name=name,
                    base_type=base,
                    model_type=model_type,
                    # Use this installed model's own recorded path.
                    path=md.get("path"),
                    installed=True,
                )

        ordered = sorted(model_dict.items(), key=lambda item: item[1].name.lower())
        return dict(ordered)

    def _is_autoloaded(self, model_info: dict) -> bool:
        """Return True if the model's path lies inside one of the configured auto-import directories."""
        path = model_info.get("path")
        if not path:
            return False
        for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]:
            if autodir_path := getattr(self.config, autodir):
                autodir_path = self.config.root_path / autodir_path
                if Path(path).is_relative_to(autodir_path):
                    return True
        return False

    def list_models(self, model_type):
        """Print name, base model and path of every installed model of ``model_type``, tab-separated."""
        installed = self.mgr.list_models(model_type=model_type)
        print(f"Installed models of type `{model_type}`:")
        for i in installed:
            print(f"{i['model_name']}\t{i['base_model']}\t{i['path']}")

    # logic here a little reversed to maintain backward compatibility
    def starter_models(self, all_models: bool = False) -> Set[str]:
        """
        Return the keys of INITIAL_MODELS.yaml entries.

        By default only Main and Vae models are returned; ``all_models=True`` returns every entry.
        """
        models = set()
        for key, value in self.datasets.items():
            name, base, model_type = ModelManager.parse_key(key)
            if all_models or model_type in [ModelType.Main, ModelType.Vae]:
                models.add(key)
        return models

    def recommended_models(self) -> Set[str]:
        """Return the keys of all catalogue entries flagged ``recommended``."""
        starters = self.starter_models(all_models=True)
        return {x for x in starters if self.datasets[x].get("recommended", False)}

    def default_model(self) -> str:
        """
        Return the key of the first Main/Vae catalogue entry flagged ``default``.

        :raises IndexError: if no such entry exists.
        """
        starters = self.starter_models()
        defaults = [x for x in starters if self.datasets[x].get("default", False)]
        return defaults[0]

    def install(self, selections: InstallSelections):
        """
        Remove, then install, the models listed in ``selections`` and commit models.yaml.

        Per-model failures (missing files on delete, ValueError/KeyError on import) are
        logged and skipped. diffusers logging verbosity is always restored on exit.
        """
        verbosity = dlogging.get_verbosity()  # quench NSFW nags
        dlogging.set_verbosity_error()

        job = 1
        jobs = len(selections.remove_models) + len(selections.install_models)

        try:
            # remove requested models
            for key in selections.remove_models:
                name, base, mtype = self.mgr.parse_key(key)
                logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]")
                try:
                    self.mgr.del_model(name, base, mtype)
                except FileNotFoundError as e:
                    logger.warning(e)
                job += 1

            # add requested models
            for path in selections.install_models:
                logger.info(f"Installing {path} [{job}/{jobs}]")
                try:
                    self.heuristic_import(path)
                except (ValueError, KeyError) as e:
                    logger.error(str(e))
                job += 1
        finally:
            # Restore verbosity even if an unexpected exception escapes the loops.
            dlogging.set_verbosity(verbosity)

        self.mgr.commit()

    def heuristic_import(
        self,
        model_path_id_or_url: Union[str, Path],
        models_installed: Optional[Dict[str, AddModelResult]] = None,
    ) -> Dict[str, AddModelResult]:
        """
        :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL
        :param models_installed: accumulator shared by recursive invocations; leave as None

        Returns a dict mapping each imported path/repo id/URL (as a string) to the result
        of installing it (None or {} when that item was skipped).

        :raises KeyError: if the argument is not an existing path, a repo id, or a URL
        :raises ValueError: propagated from _install_path when the model is already installed
        """
        # Create the accumulator only at the top-level call. A truthiness test would
        # replace an *empty* dict handed down by a recursive call, silently discarding
        # everything installed from that subtree.
        if models_installed is None:
            models_installed = dict()

        # A little hack to allow nested routines to retrieve info on the requested ID
        self.current_id = model_path_id_or_url
        path = Path(model_path_id_or_url)

        # checkpoint file, or similar
        if path.is_file():
            models_installed.update({str(path): self._install_path(path)})

        # folders style or similar
        elif path.is_dir() and any(
            (path / x).exists()
            for x in ("config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin")
        ):
            models_installed.update({str(model_path_id_or_url): self._install_path(path)})

        # recursive scan
        elif path.is_dir():
            for child in path.iterdir():
                self.heuristic_import(child, models_installed=models_installed)

        # huggingface repo
        elif len(str(model_path_id_or_url).split("/")) == 2:
            models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))})

        # a URL
        elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")):
            models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)})

        else:
            raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping")

        return models_installed

    # install a model from a local path. The optional info parameter is there to prevent
    # the model from being probed twice in the event that it has already been probed.
    def _install_path(self, path: Path, info: Optional[ModelProbeInfo] = None) -> AddModelResult:
        """
        Register the model at ``path`` with the ModelManager.

        Returns the ModelManager's add_model() result, or None if the model cannot be probed.

        :raises ValueError: if a model of the same name, base and type is already installed
        """
        info = info or ModelProbe().heuristic_probe(path, self.prediction_helper)
        if not info:
            logger.warning(f"Unable to parse format of {path}")
            return None
        model_name = path.stem if path.is_file() else path.name
        if self.mgr.model_exists(model_name, info.base_type, info.model_type):
            raise ValueError(f'A model named "{model_name}" is already installed.')
        attributes = self._make_attributes(path, info)
        return self.mgr.add_model(
            model_name=model_name,
            base_model=info.base_type,
            model_type=info.model_type,
            model_attributes=attributes,
        )

    def _install_url(self, url: str) -> AddModelResult:
        """
        Download ``url`` into a staging dir under models_path, probe it, move it into
        models_path/<base>/<type>/ and register it. Returns None on download/probe failure.
        """
        with TemporaryDirectory(dir=self.config.models_path) as staging:
            location = download_with_resume(url, Path(staging))
            if not location:
                logger.error(f"Unable to download {url}. Skipping.")
                return None
            # Probe with the prediction helper, as the other install paths do; the result
            # is reused by _install_path so the model is not probed twice.
            info = ModelProbe().heuristic_probe(location, self.prediction_helper)
            if not info:
                logger.warning(f"Unable to parse format of {url}. Skipping.")
                return None
            dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name
            dest.parent.mkdir(parents=True, exist_ok=True)
            models_path = shutil.move(location, dest)

        # staged version will be garbage-collected at this time
        return self._install_path(Path(models_path), info)

    def _install_repo(self, repo_id: str) -> AddModelResult:
        """
        Download the files of HuggingFace repo ``repo_id`` that are needed for its model
        type, copy them into models_path/<base>/<type>/<name> and register the model.
        Returns {} if the repo type cannot be determined or probed.
        """
        hinfo = HfApi().model_info(repo_id)

        # we try to figure out how to download this most economically
        # list all the files in the repo
        files = [x.rfilename for x in hinfo.siblings]
        location = None

        with TemporaryDirectory(dir=self.config.models_path) as staging:
            staging = Path(staging)
            if "model_index.json" in files:
                location = self._download_hf_pipeline(repo_id, staging)  # pipeline
            elif "unet/model.onnx" in files:
                location = self._download_hf_model(repo_id, files, staging)
            else:
                for suffix in ["safetensors", "bin"]:
                    if f"pytorch_lora_weights.{suffix}" in files:
                        # download the weights file that was actually found, not always the .bin
                        location = self._download_hf_model(
                            repo_id, [f"pytorch_lora_weights.{suffix}"], staging
                        )  # LoRA
                        break
                    elif (
                        self.config.precision == "float16" and f"diffusion_pytorch_model.fp16.{suffix}" in files
                    ):  # vae, controlnet or some other standalone
                        wanted = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"]
                        location = self._download_hf_model(repo_id, wanted, staging)
                        break
                    elif f"diffusion_pytorch_model.{suffix}" in files:
                        wanted = ["config.json", f"diffusion_pytorch_model.{suffix}"]
                        location = self._download_hf_model(repo_id, wanted, staging)
                        break
                    elif f"learned_embeds.{suffix}" in files:
                        location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
                        break
            if not location:
                logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
                return {}

            info = ModelProbe().heuristic_probe(location, self.prediction_helper)
            if not info:
                logger.warning(f"Could not probe {location}. Skipping install.")
                return {}
            dest = (
                self.config.models_path
                / info.base_type.value
                / info.model_type.value
                / self._get_model_name(repo_id, location)
            )
            if dest.exists():
                shutil.rmtree(dest)
            shutil.copytree(location, dest)
            return self._install_path(dest, info)

    def _get_model_name(self, path_name: str, location: Path) -> str:
        """
        Calculate a name for the model - primitive implementation.

        Uses the catalogue name when ``path_name`` is a known catalogue path/repo id,
        otherwise the directory name or file stem of ``location``.
        """
        if key := self.reverse_paths.get(path_name):
            (name, base, mtype) = ModelManager.parse_key(key)
            return name
        elif location.is_dir():
            return location.name
        else:
            return location.stem

    def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict:
        """
        Build the models.yaml attributes for a probed model: path (relative to models_path
        when possible), description and format, plus variant and legacy config where applicable.
        """
        model_name = path.name if path.is_dir() else path.stem
        description = f"{info.base_type.value} {info.model_type.value} model {model_name}"

        # Prefer the curated INITIAL_MODELS.yaml description when this import was requested
        # by one of its paths/repo ids. reverse_paths is keyed by strings, so normalise Path
        # ids (e.g. produced by a recursive directory scan) before the lookup.
        current_id = getattr(self, "current_id", None)
        key = self.reverse_paths.get(str(current_id)) if current_id is not None else None
        if key and key in self.datasets:
            description = self.datasets[key].get("description") or description

        rel_path = self.relative_to_root(path, self.config.models_path)

        attributes = dict(
            path=str(rel_path),
            description=str(description),
            model_format=info.format,
        )
        legacy_conf = None
        if info.model_type in (ModelType.Main, ModelType.ONNX):
            attributes["variant"] = info.variant_type
            if info.format == "checkpoint":
                try:
                    possible_conf = path.with_suffix(".yaml")
                    if possible_conf.exists():
                        legacy_conf = str(self.relative_to_root(possible_conf))
                    elif info.base_type == BaseModelType.StableDiffusion2:
                        legacy_conf = Path(
                            self.config.legacy_conf_dir,
                            LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type],
                        )
                    else:
                        legacy_conf = Path(
                            self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type]
                        )
                except KeyError:
                    legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml")  # best guess

        if info.model_type == ModelType.ControlNet and info.format == "checkpoint":
            possible_conf = path.with_suffix(".yaml")
            if possible_conf.exists():
                legacy_conf = str(self.relative_to_root(possible_conf))

        if legacy_conf:
            attributes["config"] = str(legacy_conf)
        return attributes

    def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path:
        """Return ``path`` relative to ``root`` (default: config.root_path) if it lies inside it, else unchanged."""
        root = root or self.config.root_path
        if path.is_relative_to(root):
            return path.relative_to(root)
        else:
            return path

    def _download_hf_pipeline(self, repo_id: str, staging: Path) -> Path:
        """
        This retrieves a StableDiffusion model from cache or remote and then
        does a save_pretrained() to the indicated staging area.

        The fp16 variant is tried first when the current device dtype is float16, and the
        default variant first otherwise. Returns staging/<repo name>, or None on failure.
        """
        _, name = repo_id.split("/")
        precision = torch_dtype(choose_torch_device())
        variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"]

        model = None
        for variant in variants:
            try:
                model = DiffusionPipeline.from_pretrained(
                    repo_id,
                    variant=variant,
                    torch_dtype=precision,
                    safety_checker=None,
                )
            except Exception as e:  # most errors are due to fp16 not being present. Fix this to catch other errors
                if "fp16" not in str(e):
                    logger.warning(f"{repo_id} (variant={variant}): {e}")

            if model is not None:
                break

        if model is None:
            logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.")
            return None
        model.save_pretrained(staging / name, safe_serialization=True)
        return staging / name

    def _download_hf_model(self, repo_id: str, files: List[str], staging: Path) -> Path:
        """
        Download the listed repo files into staging/<repo name>, preserving their subfolders.
        Returns that directory if at least one file was downloaded, else None.
        """
        _, name = repo_id.split("/")
        location = staging / name
        paths = list()
        for filename in files:
            file_path = Path(filename)
            parent = file_path.parent
            # Top-level files have parent "." and must be requested with no subfolder
            # (not "./<file>"). Nested folders are sent in POSIX form so the URL is also
            # correct on Windows.
            subfolder = None if parent == Path(".") else parent.as_posix()
            p = hf_download_with_resume(
                repo_id,
                model_dir=location / parent,
                model_name=file_path.name,
                access_token=self.access_token,
                subfolder=subfolder,
            )
            if p:
                paths.append(p)
            else:
                logger.warning(f"Could not download {filename} from {repo_id}.")

        return location if len(paths) > 0 else None

    @classmethod
    def _reverse_paths(cls, datasets) -> dict:
        """
        Reverse mapping from repo_id/path to destination name.
        """
        return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()}
|
|
|
|
|
2023-05-30 17:49:43 +00:00
|
|
|
|
2023-02-15 06:07:39 +00:00
|
|
|
# -------------------------------------
|
|
|
|
def yes_or_no(prompt: str, default_yes=True):
    """
    Ask a yes/no question on the console and return the answer as a bool.

    An empty reply selects the default. When the default is "yes", anything not
    starting with n/N counts as yes. When the default is "no", only replies starting
    with y/Y count as yes.
    """
    fallback = "y" if default_yes else "n"
    answer = input(f"{prompt} [{fallback}] ") or fallback
    first = answer[0]
    if not default_yes:
        return first in "yY"
    return first not in "nN"
|
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-02-15 06:07:39 +00:00
|
|
|
# ---------------------------------------------
|
2023-07-28 13:46:44 +00:00
|
|
|
def _suppress_fp16_warning(record) -> bool:
    """Logging filter: drop records whose message contains "fp16 is not a valid"."""
    return "fp16 is not a valid" not in record.getMessage()


def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs):
    """
    Load ``model_name`` with ``model_class.from_pretrained`` (resuming interrupted
    downloads) and save it to ``destination`` in safetensors format.

    :param model_class: any class exposing ``from_pretrained`` whose instances expose ``save_pretrained``
    :param model_name: repo id or path passed to ``from_pretrained``
    :param destination: directory passed to ``save_pretrained``
    :param kwargs: extra keyword arguments forwarded to ``from_pretrained``
    :return: ``destination``
    """
    logger = InvokeAILogger.getLogger("InvokeAI")
    # Use one module-level filter object. Logger.addFilter ignores a filter that is
    # already attached, whereas a fresh lambda on every call would pile up a new filter
    # on the shared logger each time.
    logger.addFilter(_suppress_fp16_warning)

    model = model_class.from_pretrained(
        model_name,
        resume_download=True,
        **kwargs,
    )
    model.save_pretrained(destination, safe_serialization=True)
    return destination
|
2023-02-15 06:07:39 +00:00
|
|
|
|
2023-07-28 13:46:44 +00:00
|
|
|
|
2023-02-15 06:07:39 +00:00
|
|
|
# ---------------------------------------------
|
|
|
|
def hf_download_with_resume(
    repo_id: str,
    model_dir: str,
    model_name: str,
    model_dest: Path = None,
    access_token: str = None,
    subfolder: str = None,
) -> Path:
    """
    Download a single file from a HuggingFace repo, resuming a partial download if present.

    :param repo_id: HuggingFace repo id ("owner/name")
    :param model_dir: directory to download into; created if missing
    :param model_name: file name within the repo (and on disk, unless ``model_dest`` is given)
    :param model_dest: explicit destination file; defaults to ``model_dir/model_name``
    :param access_token: optional token sent as a Bearer Authorization header
    :param subfolder: optional repo subfolder that contains ``model_name``
    :return: the destination path on success, or None if the file is missing, the server
        returns an error status, or the transfer fails
    """
    model_dest = model_dest or Path(os.path.join(model_dir, model_name))
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name, subfolder=subfolder)

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if os.path.exists(model_dest):
        exist_size = os.path.getsize(model_dest)
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    # The context manager closes the streamed connection on every exit path,
    # including the early returns below.
    with requests.get(url, headers=header, stream=True) as resp:
        total = int(resp.headers.get("content-length", 0))

        if resp.status_code == 416:  # "range not satisfiable", which means nothing to return
            logger.info(f"{model_name}: complete file found. Skipping.")
            return model_dest
        elif resp.status_code == 404:
            logger.warning("File not found")
            return None
        elif resp.status_code == 206:  # partial content: the server honoured our Range request
            logger.info(f"{model_name}: partial file found. Resuming...")
        elif resp.status_code == 200:
            if exist_size > 0:
                # The server ignored the Range header and is sending the whole file.
                # Start over rather than appending a full copy to the partial one.
                open_mode = "wb"
                exist_size = 0
            logger.info(f"{model_name}: Downloading...")
        else:
            # Any other status (401/403/5xx, ...) carries an error body, not model data;
            # writing it out would leave a corrupt file that looks like a success.
            logger.warning(f"{model_name}: {resp.reason}")
            return None

        try:
            with open(model_dest, open_mode) as file, tqdm(
                desc=model_name,
                initial=exist_size,
                total=total + exist_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
            ) as bar:
                for data in resp.iter_content(chunk_size=1024):
                    size = file.write(data)
                    bar.update(size)
        except Exception as e:
            # Best effort: keep the partial file so a later call can resume it.
            logger.error(f"An error occurred while downloading {model_name}: {str(e)}")
            return None

    return model_dest
|