""" Utility (backend) functions used by model_install.py """ import os import re import shutil import warnings from dataclasses import dataclass, field from pathlib import Path from tempfile import TemporaryDirectory from typing import Callable, Dict, List, Optional, Set, Union import requests import torch from diffusers import DiffusionPipeline from diffusers import logging as dlogging from huggingface_hub import HfApi, HfFolder, hf_hub_url from omegaconf import OmegaConf from tqdm import tqdm import invokeai.configs as configs from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_management import AddModelResult, BaseModelType, ModelManager, ModelType, ModelVariantType from invokeai.backend.model_management.model_probe import ModelProbe, ModelProbeInfo, SchedulerPredictionType from invokeai.backend.util import download_with_resume from invokeai.backend.util.devices import choose_torch_device, torch_dtype from ..util.logging import InvokeAILogger warnings.filterwarnings("ignore") # --------------------------globals----------------------- config = InvokeAIAppConfig.get_config() logger = InvokeAILogger.get_logger(name="InvokeAI") # the initial "configs" dir is now bundled in the `invokeai.configs` package Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" Config_preamble = """ # This file describes the alternative machine learning models # available to InvokeAI script. # # To add a new model, follow the examples below. Each # model requires a model config file, a weights file, # and the width and height of the images it # was trained on. """ LEGACY_CONFIGS = { BaseModelType.StableDiffusion1: { ModelVariantType.Normal: { SchedulerPredictionType.Epsilon: "v1-inference.yaml", SchedulerPredictionType.VPrediction: "v1-inference-v.yaml", }, ModelVariantType.Inpaint: { SchedulerPredictionType.Epsilon: "v1-inpainting-inference.yaml", SchedulerPredictionType.VPrediction: "v1-inpainting-inference-v.yaml", }, }, BaseModelType.StableDiffusion2: { ModelVariantType.Normal: { SchedulerPredictionType.Epsilon: "v2-inference.yaml", SchedulerPredictionType.VPrediction: "v2-inference-v.yaml", }, ModelVariantType.Inpaint: { SchedulerPredictionType.Epsilon: "v2-inpainting-inference.yaml", SchedulerPredictionType.VPrediction: "v2-inpainting-inference-v.yaml", }, }, BaseModelType.StableDiffusionXL: { ModelVariantType.Normal: "sd_xl_base.yaml", }, BaseModelType.StableDiffusionXLRefiner: { ModelVariantType.Normal: "sd_xl_refiner.yaml", }, } @dataclass class InstallSelections: install_models: List[str] = field(default_factory=list) remove_models: List[str] = field(default_factory=list) @dataclass class ModelLoadInfo: name: str model_type: ModelType base_type: BaseModelType path: Optional[Path] = None repo_id: Optional[str] = None subfolder: Optional[str] = None description: str = "" installed: bool = False recommended: bool = False default: bool = False requires: Optional[List[str]] = field(default_factory=list) class ModelInstall(object): def __init__( self, config: InvokeAIAppConfig, prediction_type_helper: Optional[Callable[[Path], SchedulerPredictionType]] = None, model_manager: Optional[ModelManager] = None, access_token: Optional[str] = None, ): self.config = config self.mgr = model_manager or ModelManager(config.model_conf_path) self.datasets = OmegaConf.load(Dataset_path) self.prediction_helper = prediction_type_helper self.access_token = access_token or HfFolder.get_token() self.reverse_paths = self._reverse_paths(self.datasets) def all_models(self) -> Dict[str, ModelLoadInfo]: """ Return dict of model_key=>ModelLoadInfo objects. This method consolidates and simplifies the entries in both models.yaml and INITIAL_MODELS.yaml so that they can be treated uniformly. It also sorts the models alphabetically by their name, to improve the display somewhat. """ model_dict = dict() # first populate with the entries in INITIAL_MODELS.yaml for key, value in self.datasets.items(): name, base, model_type = ModelManager.parse_key(key) value["name"] = name value["base_type"] = base value["model_type"] = model_type model_info = ModelLoadInfo(**value) if model_info.subfolder and model_info.repo_id: model_info.repo_id += f":{model_info.subfolder}" model_dict[key] = model_info # supplement with entries in models.yaml installed_models = [x for x in self.mgr.list_models()] for md in installed_models: base = md["base_model"] model_type = md["model_type"] name = md["model_name"] key = ModelManager.create_key(name, base, model_type) if key in model_dict: model_dict[key].installed = True else: model_dict[key] = ModelLoadInfo( name=name, base_type=base, model_type=model_type, path=value.get("path"), installed=True, ) return {x: model_dict[x] for x in sorted(model_dict.keys(), key=lambda y: model_dict[y].name.lower())} def _is_autoloaded(self, model_info: dict) -> bool: path = model_info.get("path") if not path: return False for autodir in ["autoimport_dir", "lora_dir", "embedding_dir", "controlnet_dir"]: if autodir_path := getattr(self.config, autodir): autodir_path = self.config.root_path / autodir_path if Path(path).is_relative_to(autodir_path): return True return False def list_models(self, model_type): installed = self.mgr.list_models(model_type=model_type) print() print(f"Installed models of type `{model_type}`:") print(f"{'Model Key':50} Model Path") for i in installed: print(f"{'/'.join([i['base_model'],i['model_type'],i['model_name']]):50} {i['path']}") print() # logic here a little reversed to maintain backward compatibility def starter_models(self, all_models: bool = False) -> Set[str]: models = set() for key, value in self.datasets.items(): name, base, model_type = ModelManager.parse_key(key) if all_models or model_type in [ModelType.Main, ModelType.Vae]: models.add(key) return models def recommended_models(self) -> Set[str]: starters = self.starter_models(all_models=True) return set([x for x in starters if self.datasets[x].get("recommended", False)]) def default_model(self) -> str: starters = self.starter_models() defaults = [x for x in starters if self.datasets[x].get("default", False)] return defaults[0] def install(self, selections: InstallSelections): verbosity = dlogging.get_verbosity() # quench NSFW nags dlogging.set_verbosity_error() job = 1 jobs = len(selections.remove_models) + len(selections.install_models) # remove requested models for key in selections.remove_models: name, base, mtype = self.mgr.parse_key(key) logger.info(f"Deleting {mtype} model {name} [{job}/{jobs}]") try: self.mgr.del_model(name, base, mtype) except FileNotFoundError as e: logger.warning(e) job += 1 # add requested models self._remove_installed(selections.install_models) self._add_required_models(selections.install_models) for path in selections.install_models: logger.info(f"Installing {path} [{job}/{jobs}]") try: self.heuristic_import(path) except (ValueError, KeyError) as e: logger.error(str(e)) job += 1 dlogging.set_verbosity(verbosity) self.mgr.commit() def heuristic_import( self, model_path_id_or_url: Union[str, Path], models_installed: Set[Path] = None, ) -> Dict[str, AddModelResult]: """ :param model_path_id_or_url: A Path to a local model to import, or a string representing its repo_id or URL :param models_installed: Set of installed models, used for recursive invocation Returns a set of dict objects corresponding to newly-created stanzas in models.yaml. """ if not models_installed: models_installed = dict() model_path_id_or_url = str(model_path_id_or_url).strip("\"' ") # A little hack to allow nested routines to retrieve info on the requested ID self.current_id = model_path_id_or_url path = Path(model_path_id_or_url) # fix relative paths if path.exists() and not path.is_absolute(): path = path.absolute() # make relative to current WD # checkpoint file, or similar if path.is_file(): models_installed.update({str(path): self._install_path(path)}) # folders style or similar elif path.is_dir() and any( [ (path / x).exists() for x in {"config.json", "model_index.json", "learned_embeds.bin", "pytorch_lora_weights.bin"} ] ): models_installed.update({str(model_path_id_or_url): self._install_path(path)}) # recursive scan elif path.is_dir(): for child in path.iterdir(): self.heuristic_import(child, models_installed=models_installed) # huggingface repo elif len(str(model_path_id_or_url).split("/")) == 2: models_installed.update({str(model_path_id_or_url): self._install_repo(str(model_path_id_or_url))}) # a URL elif str(model_path_id_or_url).startswith(("http:", "https:", "ftp:")): models_installed.update({str(model_path_id_or_url): self._install_url(model_path_id_or_url)}) else: raise KeyError(f"{str(model_path_id_or_url)} is not recognized as a local path, repo ID or URL. Skipping") return models_installed def _remove_installed(self, model_list: List[str]): all_models = self.all_models() for path in model_list: key = self.reverse_paths.get(path) if key and all_models[key].installed: logger.warning(f"{path} already installed. Skipping.") model_list.remove(path) def _add_required_models(self, model_list: List[str]): additional_models = [] all_models = self.all_models() for path in model_list: if not (key := self.reverse_paths.get(path)): continue for requirement in all_models[key].requires: requirement_key = self.reverse_paths.get(requirement) if not all_models[requirement_key].installed: additional_models.append(requirement) model_list.extend(additional_models) # install a model from a local path. The optional info parameter is there to prevent # the model from being probed twice in the event that it has already been probed. def _install_path(self, path: Path, info: ModelProbeInfo = None) -> AddModelResult: info = info or ModelProbe().heuristic_probe(path, self.prediction_helper) if not info: logger.warning(f"Unable to parse format of {path}") return None model_name = path.stem if path.is_file() else path.name if self.mgr.model_exists(model_name, info.base_type, info.model_type): raise ValueError(f'A model named "{model_name}" is already installed.') attributes = self._make_attributes(path, info) return self.mgr.add_model( model_name=model_name, base_model=info.base_type, model_type=info.model_type, model_attributes=attributes, ) def _install_url(self, url: str) -> AddModelResult: with TemporaryDirectory(dir=self.config.models_path) as staging: location = download_with_resume(url, Path(staging)) if not location: logger.error(f"Unable to download {url}. Skipping.") info = ModelProbe().heuristic_probe(location, self.prediction_helper) dest = self.config.models_path / info.base_type.value / info.model_type.value / location.name dest.parent.mkdir(parents=True, exist_ok=True) models_path = shutil.move(location, dest) # staged version will be garbage-collected at this time return self._install_path(Path(models_path), info) def _install_repo(self, repo_id: str) -> AddModelResult: # hack to recover models stored in subfolders -- # Required to get the "v2" model of monster-labs/control_v1p_sd15_qrcode_monster subfolder = None if match := re.match(r"^([^/]+/[^/]+):(\w+)$", repo_id): repo_id = match.group(1) subfolder = match.group(2) hinfo = HfApi().model_info(repo_id) # we try to figure out how to download this most economically # list all the files in the repo files = [x.rfilename for x in hinfo.siblings] if subfolder: files = [x for x in files if x.startswith(f"{subfolder}/")] prefix = f"{subfolder}/" if subfolder else "" location = None with TemporaryDirectory(dir=self.config.models_path) as staging: staging = Path(staging) if f"{prefix}model_index.json" in files: location = self._download_hf_pipeline(repo_id, staging, subfolder=subfolder) # pipeline elif f"{prefix}unet/model.onnx" in files: location = self._download_hf_model(repo_id, files, staging) else: for suffix in ["safetensors", "bin"]: if f"{prefix}pytorch_lora_weights.{suffix}" in files: location = self._download_hf_model( repo_id, ["pytorch_lora_weights.bin"], staging, subfolder=subfolder ) # LoRA break elif ( self.config.precision == "float16" and f"{prefix}diffusion_pytorch_model.fp16.{suffix}" in files ): # vae, controlnet or some other standalone files = ["config.json", f"diffusion_pytorch_model.fp16.{suffix}"] location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) break elif f"{prefix}diffusion_pytorch_model.{suffix}" in files: files = ["config.json", f"diffusion_pytorch_model.{suffix}"] location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) break elif f"{prefix}learned_embeds.{suffix}" in files: location = self._download_hf_model( repo_id, [f"learned_embeds.{suffix}"], staging, subfolder=subfolder ) break elif ( f"{prefix}image_encoder.txt" in files and f"{prefix}ip_adapter.{suffix}" in files ): # IP-Adapter files = ["image_encoder.txt", f"ip_adapter.{suffix}"] location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) break elif f"{prefix}model.{suffix}" in files and f"{prefix}config.json" in files: # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted # by InvokeAI for use with IP-Adapters. files = ["config.json", f"model.{suffix}"] location = self._download_hf_model(repo_id, files, staging, subfolder=subfolder) break if not location: logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") return {} info = ModelProbe().heuristic_probe(location, self.prediction_helper) if not info: logger.warning(f"Could not probe {location}. Skipping install.") return {} dest = ( self.config.models_path / info.base_type.value / info.model_type.value / self._get_model_name(repo_id, location) ) if dest.exists(): shutil.rmtree(dest) shutil.copytree(location, dest) return self._install_path(dest, info) def _get_model_name(self, path_name: str, location: Path) -> str: """ Calculate a name for the model - primitive implementation. """ if key := self.reverse_paths.get(path_name): (name, base, mtype) = ModelManager.parse_key(key) return name elif location.is_dir(): return location.name else: return location.stem def _make_attributes(self, path: Path, info: ModelProbeInfo) -> dict: model_name = path.name if path.is_dir() else path.stem description = f"{info.base_type.value} {info.model_type.value} model {model_name}" if key := self.reverse_paths.get(self.current_id): if key in self.datasets: description = self.datasets[key].get("description") or description rel_path = self.relative_to_root(path, self.config.models_path) attributes = dict( path=str(rel_path), description=str(description), model_format=info.format, ) legacy_conf = None if info.model_type == ModelType.Main or info.model_type == ModelType.ONNX: attributes.update( dict( variant=info.variant_type, ) ) if info.format == "checkpoint": try: possible_conf = path.with_suffix(".yaml") if possible_conf.exists(): legacy_conf = str(self.relative_to_root(possible_conf)) elif info.base_type in [BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2]: legacy_conf = Path( self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type][info.prediction_type], ) else: legacy_conf = Path( self.config.legacy_conf_dir, LEGACY_CONFIGS[info.base_type][info.variant_type] ) except KeyError: legacy_conf = Path(self.config.legacy_conf_dir, "v1-inference.yaml") # best guess if info.model_type == ModelType.ControlNet and info.format == "checkpoint": possible_conf = path.with_suffix(".yaml") if possible_conf.exists(): legacy_conf = str(self.relative_to_root(possible_conf)) if legacy_conf: attributes.update(dict(config=str(legacy_conf))) return attributes def relative_to_root(self, path: Path, root: Optional[Path] = None) -> Path: root = root or self.config.root_path if path.is_relative_to(root): return path.relative_to(root) else: return path def _download_hf_pipeline(self, repo_id: str, staging: Path, subfolder: str = None) -> Path: """ Retrieve a StableDiffusion model from cache or remote and then does a save_pretrained() to the indicated staging area. """ _, name = repo_id.split("/") precision = torch_dtype(choose_torch_device()) variants = ["fp16", None] if precision == torch.float16 else [None, "fp16"] model = None for variant in variants: try: model = DiffusionPipeline.from_pretrained( repo_id, variant=variant, torch_dtype=precision, safety_checker=None, subfolder=subfolder, ) except Exception as e: # most errors are due to fp16 not being present. Fix this to catch other errors if "fp16" not in str(e): print(e) if model: break if not model: logger.error(f"Diffusers model {repo_id} could not be downloaded. Skipping.") return None model.save_pretrained(staging / name, safe_serialization=True) return staging / name def _download_hf_model(self, repo_id: str, files: List[str], staging: Path, subfolder: None) -> Path: _, name = repo_id.split("/") location = staging / name paths = list() for filename in files: filePath = Path(filename) p = hf_download_with_resume( repo_id, model_dir=location / filePath.parent, model_name=filePath.name, access_token=self.access_token, subfolder=filePath.parent / subfolder if subfolder else filePath.parent, ) if p: paths.append(p) else: logger.warning(f"Could not download {filename} from {repo_id}.") return location if len(paths) > 0 else None @classmethod def _reverse_paths(cls, datasets) -> dict: """ Reverse mapping from repo_id/path to destination name. """ return {v.get("path") or v.get("repo_id"): k for k, v in datasets.items()} # ------------------------------------- def yes_or_no(prompt: str, default_yes=True): default = "y" if default_yes else "n" response = input(f"{prompt} [{default}] ") or default if default_yes: return response[0] not in ("n", "N") else: return response[0] in ("y", "Y") # --------------------------------------------- def hf_download_from_pretrained(model_class: object, model_name: str, destination: Path, **kwargs): logger = InvokeAILogger.get_logger("InvokeAI") logger.addFilter(lambda x: "fp16 is not a valid" not in x.getMessage()) model = model_class.from_pretrained( model_name, resume_download=True, **kwargs, ) model.save_pretrained(destination, safe_serialization=True) return destination # --------------------------------------------- def hf_download_with_resume( repo_id: str, model_dir: str, model_name: str, model_dest: Path = None, access_token: str = None, subfolder: str = None, ) -> Path: model_dest = model_dest or Path(os.path.join(model_dir, model_name)) os.makedirs(model_dir, exist_ok=True) url = hf_hub_url(repo_id, model_name, subfolder=subfolder) header = {"Authorization": f"Bearer {access_token}"} if access_token else {} open_mode = "wb" exist_size = 0 if os.path.exists(model_dest): exist_size = os.path.getsize(model_dest) header["Range"] = f"bytes={exist_size}-" open_mode = "ab" resp = requests.get(url, headers=header, stream=True) total = int(resp.headers.get("content-length", 0)) if resp.status_code == 416: # "range not satisfiable", which means nothing to return logger.info(f"{model_name}: complete file found. Skipping.") return model_dest elif resp.status_code == 404: logger.warning("File not found") return None elif resp.status_code != 200: logger.warning(f"{model_name}: {resp.reason}") elif exist_size > 0: logger.info(f"{model_name}: partial file found. Resuming...") else: logger.info(f"{model_name}: Downloading...") try: with ( open(model_dest, open_mode) as file, tqdm( desc=model_name, initial=exist_size, total=total + exist_size, unit="iB", unit_scale=True, unit_divisor=1000, ) as bar, ): for data in resp.iter_content(chunk_size=1024): size = file.write(data) bar.update(size) except Exception as e: logger.error(f"An error occurred while downloading {model_name}: {str(e)}") return None return model_dest