""" Utility (backend) functions used by model_install.py """ import os import re import shutil import sys import warnings from pathlib import Path from tempfile import TemporaryFile import requests from diffusers import AutoencoderKL from huggingface_hub import hf_hub_url from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from tqdm import tqdm from typing import List import invokeai.configs as configs from invokeai.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline from ..globals import Globals, global_cache_dir, global_config_dir from invokeai.models import ModelManager warnings.filterwarnings("ignore") # --------------------------globals----------------------- Model_dir = "models" Weights_dir = "ldm/stable-diffusion-v1/" # the initial "configs" dir is now bundled in the `invokeai.configs` package Dataset_path = Path(configs.__path__[0]) / "INITIAL_MODELS.yaml" # initial models omegaconf Datasets = None Config_preamble = """ # This file describes the alternative machine learning models # available to InvokeAI script. # # To add a new model, follow the examples below. Each # model requires a model config file, a weights file, # and the width and height of the images it # was trained on. """ def default_config_file(): return Path(global_config_dir()) / "models.yaml" def sd_configs(): return Path(global_config_dir()) / "stable-diffusion" def initial_models(): global Datasets if Datasets: return Datasets return (Datasets := OmegaConf.load(Dataset_path)) def install_requested_models( install_initial_models: List[str] = None, remove_models: List[str] = None, scan_directory: Path = None, external_models: List[str] = None, scan_at_startup: bool = False, convert_to_diffusers: bool = False, precision: str = "float16", purge_deleted: bool = False, config_file_path: Path = None, ): ''' Entry point for installing/deleting starter models, or installing external models. ''' config_file_path=config_file_path or default_config_file() if not config_file_path.exists(): open(config_file_path,'w') model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision) if remove_models and len(remove_models) > 0: print("== DELETING UNCHECKED STARTER MODELS ==") for model in remove_models: print(f'{model}...') model_manager.del_model(model, delete_files=purge_deleted) model_manager.commit(config_file_path) if install_initial_models and len(install_initial_models) > 0: print("== INSTALLING SELECTED STARTER MODELS ==") successfully_downloaded = download_weight_datasets( models=install_initial_models, access_token=None, precision=precision, ) # FIX: for historical reasons, we don't use model manager here update_config_file(successfully_downloaded, config_file_path) if len(successfully_downloaded) < len(install_initial_models): print("** Some of the model downloads were not successful") # due to above, we have to reload the model manager because conf file # was changed behind its back model_manager= ModelManager(OmegaConf.load(config_file_path),precision=precision) external_models = external_models or list() if scan_directory: external_models.append(str(scan_directory)) if len(external_models)>0: print("== INSTALLING EXTERNAL MODELS ==") for path_url_or_repo in external_models: try: model_manager.heuristic_import( path_url_or_repo, convert=convert_to_diffusers, commit_to_conf=config_file_path ) except KeyboardInterrupt: sys.exit(-1) except Exception: pass if scan_at_startup and scan_directory.is_dir(): argument = '--autoconvert' if convert_to_diffusers else '--autoimport' initfile = Path(Globals.root, Globals.initfile) replacement = Path(Globals.root, f'{Globals.initfile}.new') directory = str(scan_directory).replace('\\','/') with open(initfile,'r') as input: with open(replacement,'w') as output: while line := input.readline(): if not line.startswith(argument): output.writelines([line]) output.writelines([f'{argument} {directory}']) os.replace(replacement,initfile) # ------------------------------------- def yes_or_no(prompt: str, default_yes=True): default = "y" if default_yes else "n" response = input(f"{prompt} [{default}] ") or default if default_yes: return response[0] not in ("n", "N") else: return response[0] in ("y", "Y") # ------------------------------------- def get_root(root: str = None) -> str: if root: return root elif os.environ.get("INVOKEAI_ROOT"): return os.environ.get("INVOKEAI_ROOT") else: return Globals.root # --------------------------------------------- def recommended_datasets() -> dict: datasets = dict() for ds in initial_models().keys(): if initial_models()[ds].get("recommended", False): datasets[ds] = True return datasets # --------------------------------------------- def default_dataset() -> dict: datasets = dict() for ds in initial_models().keys(): if initial_models()[ds].get("default", False): datasets[ds] = True return datasets # --------------------------------------------- def all_datasets() -> dict: datasets = dict() for ds in initial_models().keys(): datasets[ds] = True return datasets # --------------------------------------------- # look for legacy model.ckpt in models directory and offer to # normalize its name def migrate_models_ckpt(): model_path = os.path.join(Globals.root, Model_dir, Weights_dir) if not os.path.exists(os.path.join(model_path, "model.ckpt")): return new_name = initial_models()["stable-diffusion-1.4"]["file"] print('The Stable Diffusion v4.1 "model.ckpt" is already installed. The name will be changed to {new_name} to avoid confusion.') print(f"model.ckpt => {new_name}") os.replace( os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name) ) # --------------------------------------------- def download_weight_datasets( models: List[str], access_token: str, precision: str = "float32" ): migrate_models_ckpt() successful = dict() for mod in models: print(f"Downloading {mod}:") successful[mod] = _download_repo_or_file( initial_models()[mod], access_token, precision=precision ) return successful def _download_repo_or_file( mconfig: DictConfig, access_token: str, precision: str = "float32" ) -> Path: path = None if mconfig["format"] == "ckpt": path = _download_ckpt_weights(mconfig, access_token) else: path = _download_diffusion_weights(mconfig, access_token, precision=precision) if "vae" in mconfig and "repo_id" in mconfig["vae"]: _download_diffusion_weights( mconfig["vae"], access_token, precision=precision ) return path def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path: repo_id = mconfig["repo_id"] filename = mconfig["file"] cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir) return hf_download_with_resume( repo_id=repo_id, model_dir=cache_dir, model_name=filename, access_token=access_token, ) # --------------------------------------------- def download_from_hf( model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs ): path = global_cache_dir(cache_subdir) model = model_class.from_pretrained( model_name, cache_dir=path, resume_download=True, **kwargs, ) model_name = "--".join(("models", *model_name.split("/"))) return path / model_name if model else None def _download_diffusion_weights( mconfig: DictConfig, access_token: str, precision: str = "float32" ): repo_id = mconfig["repo_id"] model_class = ( StableDiffusionGeneratorPipeline if mconfig.get("format", None) == "diffusers" else AutoencoderKL ) extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}] path = None for extra_args in extra_arg_list: try: path = download_from_hf( model_class, repo_id, cache_subdir="diffusers", safety_checker=None, **extra_args, ) except OSError as e: if str(e).startswith("fp16 is not a valid"): pass else: print(f"An unexpected error occurred while downloading the model: {e})") if path: break return path # --------------------------------------------- def hf_download_with_resume( repo_id: str, model_dir: str, model_name: str, access_token: str = None ) -> Path: model_dest = Path(os.path.join(model_dir, model_name)) os.makedirs(model_dir, exist_ok=True) url = hf_hub_url(repo_id, model_name) header = {"Authorization": f"Bearer {access_token}"} if access_token else {} open_mode = "wb" exist_size = 0 if os.path.exists(model_dest): exist_size = os.path.getsize(model_dest) header["Range"] = f"bytes={exist_size}-" open_mode = "ab" resp = requests.get(url, headers=header, stream=True) total = int(resp.headers.get("content-length", 0)) if ( resp.status_code == 416 ): # "range not satisfiable", which means nothing to return print(f"* {model_name}: complete file found. Skipping.") return model_dest elif resp.status_code != 200: print(f"** An error occurred during downloading {model_name}: {resp.reason}") elif exist_size > 0: print(f"* {model_name}: partial file found. Resuming...") else: print(f"* {model_name}: Downloading...") try: if total < 2000: print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}") return None with open(model_dest, open_mode) as file, tqdm( desc=model_name, initial=exist_size, total=total + exist_size, unit="iB", unit_scale=True, unit_divisor=1000, ) as bar: for data in resp.iter_content(chunk_size=1024): size = file.write(data) bar.update(size) except Exception as e: print(f"An error occurred while downloading {model_name}: {str(e)}") return None return model_dest # --------------------------------------------- def update_config_file(successfully_downloaded: dict, config_file: Path): config_file = ( Path(config_file) if config_file is not None else default_config_file() ) # In some cases (incomplete setup, etc), the default configs directory might be missing. # Create it if it doesn't exist. # this check is ignored if opt.config_file is specified - user is assumed to know what they # are doing if they are passing a custom config file from elsewhere. if config_file is default_config_file() and not config_file.parent.exists(): configs_src = Dataset_path.parent configs_dest = default_config_file().parent shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True) yaml = new_config_file_contents(successfully_downloaded, config_file) try: backup = None if os.path.exists(config_file): print( f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig" ) backup = config_file.with_suffix(".yaml.orig") ## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183 if sys.platform == "win32" and backup.is_file(): backup.unlink() config_file.rename(backup) with TemporaryFile() as tmp: tmp.write(Config_preamble.encode()) tmp.write(yaml.encode()) with open(str(config_file.expanduser().resolve()), "wb") as new_config: tmp.seek(0) new_config.write(tmp.read()) except Exception as e: print(f"**Error creating config file {config_file}: {str(e)} **") if backup is not None: print("restoring previous config file") ## workaround, for WinError 183, see above if sys.platform == "win32" and config_file.is_file(): config_file.unlink() backup.rename(config_file) return print(f"Successfully created new configuration file {config_file}") # --------------------------------------------- def new_config_file_contents( successfully_downloaded: dict, config_file: Path, ) -> str: if config_file.exists(): conf = OmegaConf.load(str(config_file.expanduser().resolve())) else: conf = OmegaConf.create() default_selected = None for model in successfully_downloaded: # a bit hacky - what we are doing here is seeing whether a checkpoint # version of the model was previously defined, and whether the current # model is a diffusers (indicated with a path) if conf.get(model) and Path(successfully_downloaded[model]).is_dir(): delete_weights(model, conf[model]) stanza = {} mod = initial_models()[model] stanza["description"] = mod["description"] stanza["repo_id"] = mod["repo_id"] stanza["format"] = mod["format"] # diffusers don't need width and height (probably .ckpt doesn't either) # so we no longer require these in INITIAL_MODELS.yaml if "width" in mod: stanza["width"] = mod["width"] if "height" in mod: stanza["height"] = mod["height"] if "file" in mod: stanza["weights"] = os.path.relpath( successfully_downloaded[model], start=Globals.root ) stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"])) if "vae" in mod: if "file" in mod["vae"]: stanza["vae"] = os.path.normpath( os.path.join(Model_dir, Weights_dir, mod["vae"]["file"]) ) else: stanza["vae"] = mod["vae"] if mod.get("default", False): stanza["default"] = True default_selected = True conf[model] = stanza # if no default model was chosen, then we select the first # one in the list if not default_selected: conf[list(successfully_downloaded.keys())[0]]["default"] = True return OmegaConf.to_yaml(conf) # --------------------------------------------- def delete_weights(model_name: str, conf_stanza: dict): if not (weights := conf_stanza.get("weights")): return if re.match("/VAE/", conf_stanza.get("config")): return print( f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?" ) weights = Path(weights) if not weights.is_absolute(): weights = Path(Globals.root) / weights try: weights.unlink() except OSError as e: print(str(e))