mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
This commit enhances support for V2 variant (epsilon and v-predict) import and conversion to diffusers, by prompting the user to select the proper config file during startup-time autoimport as well as in the invokeai installer script.
495 lines
17 KiB
Python
495 lines
17 KiB
Python
"""
|
|
Utility (backend) functions used by model_install.py
|
|
"""
|
|
import os
|
|
import re
|
|
import shutil
|
|
import sys
|
|
import warnings
|
|
from pathlib import Path
|
|
from tempfile import TemporaryFile
|
|
|
|
import requests
|
|
from diffusers import AutoencoderKL
|
|
from huggingface_hub import hf_hub_url
|
|
from omegaconf import OmegaConf
|
|
from omegaconf.dictconfig import DictConfig
|
|
from tqdm import tqdm
|
|
from typing import List
|
|
|
|
import invokeai.configs as configs
|
|
from ..generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
|
from ..globals import Globals, global_cache_dir, global_config_dir
|
|
from ..model_manager import ModelManager
|
|
|
|
# Ignore all Python warnings process-wide (this affects every module, not just this one).
warnings.filterwarnings("ignore")

# --------------------------globals-----------------------
# Legacy checkpoint location; joined onto Globals.root by the download helpers.
Model_dir = "models"
Weights_dir = "ldm/stable-diffusion-v1/"

# the initial "configs" dir is now bundled in the `invokeai.configs` package
Dataset_path = Path(configs.__path__[0], "INITIAL_MODELS.yaml")

# Parsed INITIAL_MODELS.yaml; filled in lazily by initial_models().
Datasets = None

# Header comment written at the top of every regenerated models.yaml.
Config_preamble = """
# This file describes the alternative machine learning models
# available to InvokeAI script.
#
# To add a new model, follow the examples below. Each
# model requires a model config file, a weights file,
# and the width and height of the images it
# was trained on.
"""
|
|
|
|
def default_config_file():
    """Return the path of ``models.yaml`` inside the global config directory."""
    config_dir = Path(global_config_dir())
    return config_dir.joinpath("models.yaml")
|
|
|
|
def sd_configs():
    """Return the directory holding the Stable Diffusion ``*.yaml`` config files."""
    config_dir = Path(global_config_dir())
    return config_dir.joinpath("stable-diffusion")
|
|
|
|
def initial_models():
    """Return the parsed INITIAL_MODELS.yaml catalog, loading it on first use.

    The result is cached in the module-level ``Datasets`` global.
    """
    global Datasets
    # Compare against None rather than truthiness: an empty (falsy) catalog is
    # still a valid cached result and must not trigger a reload on every call.
    if Datasets is not None:
        return Datasets
    Datasets = OmegaConf.load(Dataset_path)
    return Datasets
|
|
|
|
def install_requested_models(
    install_initial_models: List[str] = None,
    remove_models: List[str] = None,
    scan_directory: Path = None,
    external_models: List[str] = None,
    scan_at_startup: bool = False,
    convert_to_diffusers: bool = False,
    precision: str = "float16",
    purge_deleted: bool = False,
    config_file_path: Path = None,
):
    '''
    Entry point for installing/deleting starter models, or installing external models.

    :param install_initial_models: names of INITIAL_MODELS.yaml entries to download
        and register in the config file.
    :param remove_models: names of models to delete from the config file.
    :param scan_directory: directory handed to the importer (and, with
        ``scan_at_startup``, recorded in the init file).
    :param external_models: paths, URLs or repo ids passed to
        ``ModelManager.heuristic_import``. The caller's list is not modified.
    :param scan_at_startup: when true and ``scan_directory`` is a directory, record it
        in the init file as ``--autoimport`` (or ``--autoconvert``).
    :param convert_to_diffusers: forwarded to the importer as ``convert``; also
        selects ``--autoconvert`` over ``--autoimport``.
    :param precision: forwarded to ``ModelManager`` and the starter-model downloads.
    :param purge_deleted: also delete the files of removed models.
    :param config_file_path: models.yaml to update; defaults to ``default_config_file()``.
    '''
    # Accept str as well as Path; the code below relies on Path methods.
    config_file_path = Path(config_file_path or default_config_file())
    if not config_file_path.exists():
        # Create an empty config file without leaking an open file handle.
        config_file_path.touch()

    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    if remove_models:
        print("== DELETING UNCHECKED STARTER MODELS ==")
        for model in remove_models:
            print(f'{model}...')
            model_manager.del_model(model, delete_files=purge_deleted)
        model_manager.commit(config_file_path)

    if install_initial_models:
        _install_starter_models(install_initial_models, precision, config_file_path)

    # due to above, we have to reload the model manager because conf file
    # was changed behind its back
    model_manager = ModelManager(OmegaConf.load(config_file_path), precision=precision)

    # Copy, so appending scan_directory never mutates the caller's list.
    external_models = list(external_models or [])
    if scan_directory:
        external_models.append(str(scan_directory))

    if external_models:
        print("== INSTALLING EXTERNAL MODELS ==")
        for path_url_or_repo in external_models:
            try:
                model_manager.heuristic_import(
                    path_url_or_repo,
                    convert=convert_to_diffusers,
                    config_file_callback=_pick_configuration_file,
                    commit_to_conf=config_file_path,
                )
            except KeyboardInterrupt:
                sys.exit(-1)
            except Exception as e:
                # Best effort: keep going with the remaining models, but say what failed.
                print(f"** Could not import {path_url_or_repo}: {str(e)}")

    # Guard against scan_directory being None before calling Path methods on it.
    if scan_at_startup and scan_directory and Path(scan_directory).is_dir():
        _record_startup_scan_directory(Path(scan_directory), convert_to_diffusers)


def _install_starter_models(models: List[str], precision: str, config_file_path: Path):
    """Download the named starter models and add the successful ones to the config file."""
    print("== INSTALLING SELECTED STARTER MODELS ==")
    downloaded = download_weight_datasets(
        models=models,
        access_token=None,
        precision=precision,
    )  # FIX: for historical reasons, we don't use model manager here
    # Failed downloads come back as falsy (None) paths; register only real ones.
    succeeded = {name: path for name, path in downloaded.items() if path}
    if succeeded:
        update_config_file(succeeded, config_file_path)
    if len(succeeded) < len(models):
        print("** Some of the model downloads were not successful")


def _record_startup_scan_directory(scan_directory: Path, convert_to_diffusers: bool):
    """Rewrite the init file so it ends with an autoimport/autoconvert line for scan_directory.

    Any existing line that starts with the same option is dropped first.
    """
    argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
    initfile = Path(Globals.root, Globals.initfile)
    replacement = Path(Globals.root, f'{Globals.initfile}.new')
    directory = str(scan_directory).replace('\\', '/')

    kept_lines = []
    if initfile.exists():
        with open(initfile, 'r') as src:
            kept_lines = [line for line in src if not line.startswith(argument)]

    with open(replacement, 'w') as dest:
        for line in kept_lines:
            # Ensure each kept line is newline-terminated so the new option is not
            # glued onto the end of the previous last line.
            dest.write(line if line.endswith('\n') else line + '\n')
        dest.write(f'{argument} "{directory}"\n')
    os.replace(replacement, initfile)
|
|
|
|
# -------------------------------------
|
|
def yes_or_no(prompt: str, default_yes=True):
    """Ask a yes/no question on stdin and return the answer as a bool.

    An empty or all-whitespace reply selects the default. When ``default_yes`` is
    true, any reply not starting with n/N counts as yes. Otherwise only replies
    starting with y/Y count as yes.
    """
    default = "y" if default_yes else "n"
    # Strip surrounding whitespace so " n" is read as "n", not as the default.
    response = input(f"{prompt} [{default}] ").strip() or default
    if default_yes:
        return response[0] not in ("n", "N")
    return response[0] in ("y", "Y")
|
|
|
|
# -------------------------------------
|
|
def _pick_configuration_file(checkpoint_path: Path)->Path:
    """Interactively ask the user which Stable Diffusion config matches a checkpoint.

    Used as the ``config_file_callback`` for ``ModelManager.heuristic_import``.
    ``checkpoint_path`` is accepted for the callback signature but is not used here.

    :return: absolute Path of the chosen config file, or None if the user chose
        [Q] or stdin reached EOF.
    """
    print(
"""
Please select the type of this model:
[1] A Stable Diffusion v1.x ckpt/safetensors model
[2] A Stable Diffusion v1.x inpainting ckpt/safetensors model
[3] A Stable Diffusion v2.x base model (512 pixels; no 'parameterization:' in its yaml file)
[4] A Stable Diffusion v2.x v-predictive model (768 pixels; look for 'parameterization: "v"' in its yaml file)
[5] Other (you will be prompted to enter the config file path)
[Q] I have no idea! Skip the import.
""")
    # sd_configs() wraps global_config_dir() in Path, so '/' is always valid here.
    choices = [
        sd_configs() / name
        for name in (
            'v1-inference.yaml',
            'v1-inpainting-inference.yaml',
            'v2-inference.yaml',
            'v2-inference-v.yaml',
        )
    ]

    while True:
        try:
            choice = input('select 1-5, Q > ').strip()
        except EOFError:
            return None

        if choice.startswith(('q', 'Q')):
            return None

        if choice == '5':
            try:
                path = Path(input('Select config file for this model> ').strip()).expanduser().absolute()
            except EOFError:
                return None
            # Require a regular file: an empty reply would otherwise resolve to the cwd.
            if path.is_file():
                return path
            print(f'{path} is not a valid config file')
            continue

        # Explicit range check: a bare choices[int(choice)-1] would silently
        # accept "0" or negative numbers through negative indexing.
        try:
            index = int(choice)
        except ValueError:
            index = 0
        if 1 <= index <= len(choices):
            return choices[index - 1]
        print(f'{choice} is not a valid choice')
|
|
|
|
# -------------------------------------
|
|
def get_root(root: str = None) -> str:
    """Resolve the InvokeAI root directory.

    Precedence: a truthy ``root`` argument, then a non-empty ``INVOKEAI_ROOT``
    environment variable, then ``Globals.root``.
    """
    if root:
        return root
    env_root = os.environ.get("INVOKEAI_ROOT")
    return env_root if env_root else Globals.root
|
|
|
|
|
|
# ---------------------------------------------
|
|
def recommended_datasets() -> dict:
    """Return ``{name: True}`` for every starter model flagged ``recommended``."""
    catalog = initial_models()
    return {
        name: True
        for name in catalog.keys()
        if catalog[name].get("recommended", False)
    }
|
|
|
|
|
|
# ---------------------------------------------
|
|
def default_dataset() -> dict:
    """Return ``{name: True}`` for every starter model flagged ``default``."""
    catalog = initial_models()
    return {
        name: True
        for name in catalog.keys()
        if catalog[name].get("default", False)
    }
|
|
|
|
|
|
# ---------------------------------------------
|
|
def all_datasets() -> dict:
    """Return ``{name: True}`` for every starter model in INITIAL_MODELS.yaml."""
    return dict.fromkeys(initial_models().keys(), True)
|
|
|
|
|
|
# ---------------------------------------------
|
|
def migrate_models_ckpt():
    """Rename a legacy ``model.ckpt`` in the weights directory to its catalog name.

    The new file name comes from the ``stable-diffusion-1.4`` entry of
    INITIAL_MODELS.yaml. Nothing happens when no ``model.ckpt`` exists, or when
    the catalog has no such entry. ``os.replace`` overwrites an existing file
    with the new name.
    """
    model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
    legacy_file = os.path.join(model_path, "model.ckpt")
    if not os.path.exists(legacy_file):
        return
    entry = initial_models().get("stable-diffusion-1.4")
    if not entry or "file" not in entry:
        # Without a target name there is nothing sensible to rename to.
        return
    new_name = entry["file"]
    print(
        f'The Stable Diffusion v1.4 "model.ckpt" is already installed. '
        f'The name will be changed to {new_name} to avoid confusion.'
    )
    print(f"model.ckpt => {new_name}")
    os.replace(legacy_file, os.path.join(model_path, new_name))
|
|
|
|
|
|
# ---------------------------------------------
|
|
def download_weight_datasets(
    models: List[str], access_token: str, precision: str = "float32"
):
    """Download each named starter model.

    Runs ``migrate_models_ckpt()`` first. Returns a dict mapping every requested
    model name to whatever ``_download_repo_or_file`` returned for it, which is
    a local path, or None when that download failed.
    """
    migrate_models_ckpt()
    results = dict()
    for model_name in models:
        print(f"Downloading {model_name}:")
        model_config = initial_models()[model_name]
        results[model_name] = _download_repo_or_file(
            model_config, access_token, precision=precision
        )
    return results
|
|
|
|
|
|
def _download_repo_or_file(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
) -> Path:
    """Download one catalog entry and return its local location (None on failure).

    ``format: ckpt`` entries are fetched as a single file. Anything else is
    fetched as a diffusers repo, together with its ``vae`` repo when one is given.
    """
    if mconfig["format"] != "ckpt":
        location = _download_diffusion_weights(mconfig, access_token, precision=precision)
        vae_config = mconfig["vae"] if "vae" in mconfig else None
        if vae_config is not None and "repo_id" in vae_config:
            _download_diffusion_weights(vae_config, access_token, precision=precision)
        return location
    return _download_ckpt_weights(mconfig, access_token)
|
|
|
|
|
|
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
    """Fetch a single checkpoint file into ``Globals.root/Model_dir/Weights_dir``."""
    repo, file_name = mconfig["repo_id"], mconfig["file"]
    target_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
    return hf_download_with_resume(
        repo_id=repo,
        model_dir=target_dir,
        model_name=file_name,
        access_token=access_token,
    )
|
|
|
|
|
|
# ---------------------------------------------
|
|
def download_from_hf(
    model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
):
    """Populate the HF cache with ``model_name`` via ``model_class.from_pretrained``.

    Returns the hub-style cache directory (``models--<org>--<name>``) under the
    cache subdir, or None if ``from_pretrained`` returned a falsy value.
    """
    cache_root = global_cache_dir(cache_subdir)
    loaded = model_class.from_pretrained(
        model_name,
        cache_dir=cache_root,
        resume_download=True,
        **kwargs,
    )
    if not loaded:
        return None
    return cache_root / "--".join(["models", *model_name.split("/")])
|
|
|
|
|
|
def _download_diffusion_weights(
    mconfig: DictConfig, access_token: str, precision: str = "float32"
):
    """Fetch a diffusers pipeline (or a standalone VAE) from the Hugging Face hub.

    :param mconfig: catalog entry with at least ``repo_id``. Entries whose
        ``format`` is ``diffusers`` load as a pipeline; anything else (e.g. the
        ``vae`` sub-entry passed by ``_download_repo_or_file``) loads as AutoencoderKL.
    :param access_token: currently unused.
    :param precision: with ``float16``, the repo's ``fp16`` revision is tried first,
        falling back to the default revision.
    :return: the cache path from ``download_from_hf``, or None if every attempt failed.
    """
    repo_id = mconfig["repo_id"]
    model_class = (
        StableDiffusionGeneratorPipeline
        if mconfig.get("format", None) == "diffusers"
        else AutoencoderKL
    )
    extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
    path = None
    for extra_args in extra_arg_list:
        try:
            path = download_from_hf(
                model_class,
                repo_id,
                safety_checker=None,
                **extra_args,
            )
        except OSError as e:
            if str(e).startswith("fp16 is not a valid"):
                pass  # repo has no fp16 revision; fall through to the next attempt
            else:
                print(f"An unexpected error occurred while downloading the model: {e}")
        if path:
            break
    return path
|
|
|
|
|
|
# ---------------------------------------------
|
|
def hf_download_with_resume(
    repo_id: str, model_dir: str, model_name: str, access_token: str = None
) -> Path:
    """Download ``model_name`` from the Hugging Face repo ``repo_id`` into ``model_dir``.

    If a local file of the same name exists, it is treated as a partial download
    and resumed with an HTTP ``Range`` request.

    :param repo_id: Hugging Face repository id.
    :param model_dir: destination directory (created if missing).
    :param model_name: file name inside the repo; also used as the local file name.
    :param access_token: optional bearer token for the Authorization header.
    :return: path of the local file, or None if the download failed.
    """
    model_dest = Path(model_dir, model_name)
    os.makedirs(model_dir, exist_ok=True)

    url = hf_hub_url(repo_id, model_name)

    header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
    open_mode = "wb"
    exist_size = 0

    if model_dest.exists():
        exist_size = model_dest.stat().st_size
        header["Range"] = f"bytes={exist_size}-"
        open_mode = "ab"

    try:
        resp = requests.get(url, headers=header, stream=True)
    except requests.RequestException as e:
        print(f"** An error occurred during downloading {model_name}: {str(e)}")
        return None

    # Context manager guarantees the streamed connection is released on every path.
    with resp:
        if resp.status_code == 416:
            # "range not satisfiable", which means nothing to return
            print(f"* {model_name}: complete file found. Skipping.")
            return model_dest
        if resp.status_code == 200 and exist_size > 0:
            # The server ignored the Range header and is sending the whole file;
            # appending it to the partial file would corrupt it, so start over.
            print(f"* {model_name}: server does not support resuming. Restarting download...")
            open_mode = "wb"
            exist_size = 0
        elif resp.status_code not in (200, 206):
            # 206 Partial Content is the normal reply to a successful Range request.
            # Any other status is an error; don't write its body into the model file.
            print(f"** An error occurred during downloading {model_name}: {resp.reason}")
            return None
        elif exist_size > 0:
            print(f"* {model_name}: partial file found. Resuming...")
        else:
            print(f"* {model_name}: Downloading...")

        total = int(resp.headers.get("content-length", 0))
        try:
            # Tiny (or unknown-length) bodies are treated as error payloads, not weights.
            if total < 2000:
                print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
                return None

            with open(model_dest, open_mode) as file, tqdm(
                desc=model_name,
                initial=exist_size,
                total=total + exist_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
            ) as bar:
                for data in resp.iter_content(chunk_size=1024):
                    size = file.write(data)
                    bar.update(size)
        except Exception as e:
            print(f"An error occurred while downloading {model_name}: {str(e)}")
            return None
    return model_dest
|
|
|
|
|
|
# ---------------------------------------------
|
|
def update_config_file(successfully_downloaded: dict, config_file: Path):
    """Merge newly downloaded starter models into a models.yaml file.

    An existing file is first renamed to ``<stem>.yaml.orig``. If writing the new
    file fails, that backup is restored.

    :param successfully_downloaded: mapping of model name -> local path, passed to
        ``new_config_file_contents``.
    :param config_file: file to write; ``None`` selects ``default_config_file()``.
    """
    config_file = (
        Path(config_file) if config_file is not None else default_config_file()
    )

    # In some cases (incomplete setup, etc), the default configs directory might be missing.
    # Create it if it doesn't exist.
    # This check is skipped when a custom config file is specified - the user is assumed
    # to know what they are doing if they pass a config file from elsewhere.
    # Compare with ==: two separately constructed Path objects are never `is`-identical,
    # so the previous identity test could never be true.
    if config_file == default_config_file() and not config_file.parent.exists():
        configs_src = Dataset_path.parent
        configs_dest = default_config_file().parent
        shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)

    contents = new_config_file_contents(successfully_downloaded, config_file)

    backup = None
    try:
        if config_file.exists():
            print(
                f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
            )
            backup = config_file.with_suffix(".yaml.orig")
            # Path.replace (os.replace) overwrites an existing target on every
            # platform, so no Windows-specific unlink is needed.
            config_file.replace(backup)

        # Write bytes so no newline translation happens on Windows.
        config_file.expanduser().resolve().write_bytes(
            (Config_preamble + contents).encode()
        )
    except Exception as e:
        print(f"**Error creating config file {config_file}: {str(e)} **")
        if backup is not None:
            print("restoring previous config file")
            backup.replace(config_file)
        return

    print(f"Successfully created new configuration file {config_file}")
|
|
|
|
|
|
# ---------------------------------------------
|
|
def new_config_file_contents(
    successfully_downloaded: dict, config_file: Path,
) -> str:
    """Return YAML for ``config_file`` with stanzas for the downloaded starter models.

    Existing entries in ``config_file`` (if it exists) are kept. Each model whose
    path is None (a failed download) is skipped. If none of the written models is
    marked ``default``, the first written model becomes the default.

    :param successfully_downloaded: mapping of catalog model name -> local path or None.
    :param config_file: current models.yaml; it is read, not modified.
    :return: the merged configuration serialized as YAML.
    """
    if config_file.exists():
        conf = OmegaConf.load(str(config_file.expanduser().resolve()))
    else:
        conf = OmegaConf.create()

    default_selected = None
    written = []
    for model, location in successfully_downloaded.items():
        if location is None:
            # Download failed: don't write a stanza that points at nothing.
            continue

        # a bit hacky - what we are doing here is seeing whether a checkpoint
        # version of the model was previously defined, and whether the current
        # model is a diffusers (indicated with a path)
        if conf.get(model) and Path(location).is_dir():
            delete_weights(model, conf[model])

        stanza = {}
        mod = initial_models()[model]
        stanza["description"] = mod["description"]
        stanza["repo_id"] = mod["repo_id"]
        stanza["format"] = mod["format"]
        # diffusers don't need width and height (probably .ckpt doesn't either)
        # so we no longer require these in INITIAL_MODELS.yaml
        if "width" in mod:
            stanza["width"] = mod["width"]
        if "height" in mod:
            stanza["height"] = mod["height"]
        if "file" in mod:
            stanza["weights"] = os.path.relpath(location, start=Globals.root)
            stanza["config"] = os.path.normpath(os.path.join(sd_configs(), mod["config"]))
        if "vae" in mod:
            if "file" in mod["vae"]:
                stanza["vae"] = os.path.normpath(
                    os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
                )
            else:
                stanza["vae"] = mod["vae"]
        if mod.get("default", False):
            stanza["default"] = True
            default_selected = True

        conf[model] = stanza
        written.append(model)

    # if no default model was chosen, then we select the first one written
    # (guarded so an empty/all-failed batch doesn't raise IndexError)
    if not default_selected and written:
        conf[written[0]]["default"] = True

    return OmegaConf.to_yaml(conf)
|
|
|
|
|
|
# ---------------------------------------------
|
|
def delete_weights(model_name: str, conf_stanza: dict):
    """Delete the checkpoint file referenced by a superseded config stanza.

    Does nothing when the stanza has no ``weights``, or when its ``config`` path
    contains a ``/VAE/`` directory component. Relative weight paths are resolved
    against ``Globals.root``. OS errors during deletion are printed, not raised.
    """
    if not (weights := conf_stanza.get("weights")):
        return
    # re.search, not re.match: match is anchored at the start of the string, so
    # "/VAE/" would only ever match paths that literally begin with "/VAE/".
    # `or ""` avoids a TypeError when the stanza has no config entry.
    if re.search("/VAE/", conf_stanza.get("config") or ""):
        return

    print(
        f"\n** The checkpoint version of {model_name} is superseded by the diffusers version. Deleting the original file {weights}?"
    )

    weights = Path(weights)
    if not weights.is_absolute():
        weights = Path(Globals.root) / weights
    try:
        weights.unlink()
    except OSError as e:
        print(str(e))
|