mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add new console frontend to initial model selection, and other improvements
1. The invokeai-configure script has now been refactored. The work of selecting and downloading initial models at install time is now done by a script named invokeai-initial-models (module name is ldm.invoke.config.initial_model_select) The calling arguments for invokeai-configure have not changed, so nothing should break. After initializing the root directory, the script calls invokeai-initial-models to let the user select the starting models to install. 2. invokeai-initial-models puts up a console GUI with checkboxes to indicate which models to install. It respects the --default_only and --yes arguments so that CI will continue to work. 3. User can now edit the VAE assigned to diffusers models in the CLI. 4. Fixed a bug that caused a crash during model loading when the VAE is set to None, rather than being empty.
This commit is contained in:
parent
3dd7393984
commit
714fff39ba
@ -12,8 +12,9 @@ echo 2. browser-based UI
|
||||
echo 3. run textual inversion training
|
||||
echo 4. merge models (diffusers type only)
|
||||
echo 5. re-run the configure script to download new models
|
||||
echo 6. open the developer console
|
||||
echo 7. command-line help
|
||||
echo 6. download more starter models from HuggingFace
|
||||
echo 7. open the developer console
|
||||
echo 8. command-line help
|
||||
set /P restore="Please enter 1, 2, 3, 4, 5, 6 or 7: [2] "
|
||||
if not defined restore set restore=2
|
||||
IF /I "%restore%" == "1" (
|
||||
@ -32,6 +33,9 @@ IF /I "%restore%" == "1" (
|
||||
echo Running invokeai-configure...
|
||||
python .venv\Scripts\invokeai-configure.exe %*
|
||||
) ELSE IF /I "%restore%" == "6" (
|
||||
echo Running invokeai-initial-models...
|
||||
python .venv\Scripts\invokeai-initial-models.exe %*
|
||||
) ELSE IF /I "%restore%" == "7" (
|
||||
echo Developer Console
|
||||
echo Python command is:
|
||||
where python
|
||||
@ -43,7 +47,7 @@ IF /I "%restore%" == "1" (
|
||||
echo *************************
|
||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||
call cmd /k
|
||||
) ELSE IF /I "%restore%" == "7" (
|
||||
) ELSE IF /I "%restore%" == "8" (
|
||||
echo Displaying command line help...
|
||||
python .venv\Scripts\invokeai.exe --help %*
|
||||
pause
|
||||
|
@ -30,11 +30,12 @@ if [ "$0" != "bash" ]; then
|
||||
echo "2. browser-based UI"
|
||||
echo "3. run textual inversion training"
|
||||
echo "4. merge models (diffusers type only)"
|
||||
echo "5. open the developer console"
|
||||
echo "6. re-run the configure script to download new models"
|
||||
echo "7. command-line help "
|
||||
echo "5. re-run the configure script to fix a broken install"
|
||||
echo "6. download more starter models from HuggingFace"
|
||||
echo "7. open the developer console"
|
||||
echo "8. command-line help "
|
||||
echo ""
|
||||
read -p "Please enter 1, 2, 3, 4, 5, 6 or 7: [2] " yn
|
||||
read -p "Please enter 1, 2, 3, 4, 5, 6, 7 or 8: [2] " yn
|
||||
choice=${yn:='2'}
|
||||
case $choice in
|
||||
1)
|
||||
@ -54,14 +55,17 @@ if [ "$0" != "bash" ]; then
|
||||
exec invokeai-merge --gui $@
|
||||
;;
|
||||
5)
|
||||
exec invokeai-configure --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
6)
|
||||
exec invokeai-initial-models --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
7)
|
||||
echo "Developer Console:"
|
||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||
bash --init-file "$file_name"
|
||||
;;
|
||||
6)
|
||||
exec invokeai-configure --root ${INVOKEAI_ROOT}
|
||||
;;
|
||||
7)
|
||||
8)
|
||||
exec invokeai --help
|
||||
;;
|
||||
*)
|
||||
|
@ -68,7 +68,7 @@ trinart-characters-2_0:
|
||||
width: 512
|
||||
height: 512
|
||||
recommended: False
|
||||
ft-mse-improved-autoencoder-840000:
|
||||
autoencoder-840000:
|
||||
description: StabilityAI improved autoencoder fine-tuned for human faces. Improves legacy .ckpt models (335 MB)
|
||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||
format: ckpt
|
||||
|
@ -820,6 +820,18 @@ def edit_model(model_name:str, gen, opt, completer):
|
||||
completer.set_line(info[attribute])
|
||||
info[attribute] = input(f'{attribute}: ') or info[attribute]
|
||||
|
||||
if info['format'] == 'diffusers':
|
||||
vae = info.get('vae',dict(repo_id=None,path=None,subfolder=None))
|
||||
completer.set_line(vae.get('repo_id') or 'stabilityai/sd-vae-ft-mse')
|
||||
vae['repo_id'] = input('External VAE repo_id: ').strip() or None
|
||||
if not vae['repo_id']:
|
||||
completer.set_line(vae.get('path') or '')
|
||||
vae['path'] = input('Path to a local diffusers VAE model (usually none): ').strip() or None
|
||||
completer.set_line(vae.get('subfolder') or '')
|
||||
vae['subfolder'] = input('Name of subfolder containing the VAE model (usually none): ').strip() or None
|
||||
info['vae'] = vae
|
||||
|
||||
|
||||
if new_name != model_name:
|
||||
manager.del_model(model_name)
|
||||
|
||||
|
@ -10,37 +10,36 @@ print("Loading Python libraries...\n")
|
||||
import argparse
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryFile
|
||||
from typing import Union
|
||||
from urllib import request
|
||||
|
||||
import requests
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from getpass_asterisk import getpass_asterisk
|
||||
from huggingface_hub import HfFolder, hf_hub_url
|
||||
from huggingface_hub import HfFolder
|
||||
from huggingface_hub import login as hf_hub_login
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from tqdm import tqdm
|
||||
from transformers import (AutoProcessor, CLIPSegForImageSegmentation,
|
||||
CLIPTextModel, CLIPTokenizer)
|
||||
from transformers import (
|
||||
AutoProcessor,
|
||||
CLIPSegForImageSegmentation,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
)
|
||||
|
||||
import invokeai.configs as configs
|
||||
from ldm.invoke.devices import choose_precision, choose_torch_device
|
||||
from ldm.invoke.generator.diffusers_pipeline import \
|
||||
StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
||||
from ldm.invoke.config.initial_model_select import (
|
||||
download_from_hf,
|
||||
select_and_download_models,
|
||||
yes_or_no,
|
||||
)
|
||||
from ldm.invoke.globals import Globals, global_config_dir
|
||||
from ldm.invoke.readline import generic_completer
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
import torch
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
@ -104,125 +103,6 @@ Have fun!
|
||||
|
||||
print(message)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def yes_or_no(prompt: str, default_yes=True):
|
||||
completer.set_options(["yes", "no"])
|
||||
completer.complete_extensions(None) # turn off path-completion mode
|
||||
default = "y" if default_yes else "n"
|
||||
response = input(f"{prompt} [{default}] ") or default
|
||||
if default_yes:
|
||||
return response[0] not in ("n", "N")
|
||||
else:
|
||||
return response[0] in ("y", "Y")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def user_wants_to_download_weights() -> str:
|
||||
"""
|
||||
Returns one of "skip", "recommended" or "customized"
|
||||
"""
|
||||
print(
|
||||
"""You can download and configure the weights files manually or let this
|
||||
script do it for you. Manual installation is described at:
|
||||
|
||||
https://invoke-ai.github.io/InvokeAI/installation/020_INSTALL_MANUAL/
|
||||
|
||||
You may download the recommended models (about 15GB total), install all models (40 GB!!)
|
||||
select a customized set, or completely skip this step.
|
||||
"""
|
||||
)
|
||||
completer.set_options(["recommended", "customized", "skip"])
|
||||
completer.complete_extensions(None) # turn off path-completion mode
|
||||
selection = None
|
||||
while selection is None:
|
||||
choice = input(
|
||||
"Download <r>ecommended models, <a>ll models, <c>ustomized list, or <s>kip this step? [r]: "
|
||||
)
|
||||
if choice.startswith(("r", "R")) or len(choice) == 0:
|
||||
selection = "recommended"
|
||||
elif choice.startswith(("c", "C")):
|
||||
selection = "customized"
|
||||
elif choice.startswith(("a", "A")):
|
||||
selection = "all"
|
||||
elif choice.startswith(("s", "S")):
|
||||
selection = "skip"
|
||||
return selection
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def select_datasets(action: str):
|
||||
done = False
|
||||
default_datasets = default_dataset()
|
||||
while not done:
|
||||
datasets = dict()
|
||||
counter = 1
|
||||
|
||||
if action == "customized":
|
||||
print(
|
||||
"""
|
||||
Choose the weight file(s) you wish to download. Before downloading you
|
||||
will be given the option to view and change your selections.
|
||||
"""
|
||||
)
|
||||
for ds in Datasets.keys():
|
||||
recommended = Datasets[ds].get("recommended", False)
|
||||
r_str = "(recommended)" if recommended else ""
|
||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {r_str}')
|
||||
if yes_or_no(" Download?", default_yes=recommended):
|
||||
datasets[ds] = True
|
||||
counter += 1
|
||||
else:
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
counter += 1
|
||||
|
||||
print("The following weight files will be downloaded:")
|
||||
counter = 1
|
||||
for ds in datasets:
|
||||
dflt = "*" if ds in default_datasets else ""
|
||||
print(f" [{counter}] {ds}{dflt}")
|
||||
counter += 1
|
||||
print("* default")
|
||||
ok_to_download = yes_or_no("Ok to download?")
|
||||
if not ok_to_download:
|
||||
if yes_or_no("Change your selection?"):
|
||||
action = "customized"
|
||||
pass
|
||||
else:
|
||||
done = True
|
||||
else:
|
||||
done = True
|
||||
return datasets if ok_to_download else None
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def recommended_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("recommended", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def default_dataset() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds].get("default", False):
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def all_datasets() -> dict:
|
||||
datasets = dict()
|
||||
for ds in Datasets.keys():
|
||||
datasets[ds] = True
|
||||
return datasets
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def HfLogin(access_token) -> str:
|
||||
"""
|
||||
@ -242,7 +122,7 @@ def HfLogin(access_token) -> str:
|
||||
|
||||
|
||||
# -------------------------------Authenticate against Hugging Face
|
||||
def authenticate(yes_to_all=False):
|
||||
def save_hf_token(yes_to_all=False):
|
||||
print("** LICENSE AGREEMENT FOR WEIGHT FILES **")
|
||||
print("=" * shutil.get_terminal_size()[0])
|
||||
print(
|
||||
@ -356,149 +236,6 @@ You may re-run the configuration script again in the future if you do not wish t
|
||||
return access_token
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
||||
return
|
||||
new_name = Datasets["stable-diffusion-1.4"]["file"]
|
||||
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
|
||||
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
|
||||
if rename:
|
||||
print(f"model.ckpt => {new_name}")
|
||||
os.replace(
|
||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_weight_datasets(
|
||||
models: dict, access_token: str, precision: str = "float32"
|
||||
):
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models.keys():
|
||||
print(f"Downloading {mod}:")
|
||||
successful[mod] = _download_repo_or_file(
|
||||
Datasets[mod], access_token, precision=precision
|
||||
)
|
||||
return successful
|
||||
|
||||
|
||||
def _download_repo_or_file(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
) -> Path:
|
||||
path = None
|
||||
if mconfig["format"] == "ckpt":
|
||||
path = _download_ckpt_weights(mconfig, access_token)
|
||||
else:
|
||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
||||
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
|
||||
_download_diffusion_weights(
|
||||
mconfig["vae"], access_token, precision=precision
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
||||
repo_id = mconfig["repo_id"]
|
||||
filename = mconfig["file"]
|
||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
||||
return hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=cache_dir,
|
||||
model_name=filename,
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
|
||||
def _download_diffusion_weights(
|
||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
||||
):
|
||||
repo_id = mconfig["repo_id"]
|
||||
model_class = (
|
||||
StableDiffusionGeneratorPipeline
|
||||
if mconfig.get("format", None) == "diffusers"
|
||||
else AutoencoderKL
|
||||
)
|
||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
||||
path = None
|
||||
for extra_args in extra_arg_list:
|
||||
try:
|
||||
path = download_from_hf(
|
||||
model_class,
|
||||
repo_id,
|
||||
cache_subdir="diffusers",
|
||||
safety_checker=None,
|
||||
**extra_args,
|
||||
)
|
||||
except OSError as e:
|
||||
if str(e).startswith("fp16 is not a valid"):
|
||||
pass
|
||||
else:
|
||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
||||
if path:
|
||||
break
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def hf_download_with_resume(
|
||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
||||
) -> Path:
|
||||
model_dest = Path(os.path.join(model_dir, model_name))
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
|
||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
||||
open_mode = "wb"
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header["Range"] = f"bytes={exist_size}-"
|
||||
open_mode = "ab"
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get("content-length", 0))
|
||||
|
||||
if (
|
||||
resp.status_code == 416
|
||||
): # "range not satisfiable", which means nothing to return
|
||||
print(f"* {model_name}: complete file found. Skipping.")
|
||||
return model_dest
|
||||
elif resp.status_code != 200:
|
||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
||||
elif exist_size > 0:
|
||||
print(f"* {model_name}: partial file found. Resuming...")
|
||||
else:
|
||||
print(f"* {model_name}: Downloading...")
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
||||
return None
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total + exist_size,
|
||||
unit="iB",
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
||||
return None
|
||||
return model_dest
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||
try:
|
||||
@ -517,125 +254,6 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
||||
print(f"Error downloading {label} model")
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def update_config_file(successfully_downloaded: dict, opt: dict):
|
||||
config_file = (
|
||||
Path(opt.config_file) if opt.config_file is not None else Default_config_file
|
||||
)
|
||||
|
||||
# In some cases (incomplete setup, etc), the default configs directory might be missing.
|
||||
# Create it if it doesn't exist.
|
||||
# this check is ignored if opt.config_file is specified - user is assumed to know what they
|
||||
# are doing if they are passing a custom config file from elsewhere.
|
||||
if config_file is Default_config_file and not config_file.parent.exists():
|
||||
configs_src = Dataset_path.parent
|
||||
configs_dest = Default_config_file.parent
|
||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
||||
|
||||
yaml = new_config_file_contents(successfully_downloaded, config_file, opt)
|
||||
|
||||
try:
|
||||
backup = None
|
||||
if os.path.exists(config_file):
|
||||
print(
|
||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
||||
)
|
||||
backup = config_file.with_suffix(".yaml.orig")
|
||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
||||
if sys.platform == "win32" and backup.is_file():
|
||||
backup.unlink()
|
||||
config_file.rename(backup)
|
||||
|
||||
with TemporaryFile() as tmp:
|
||||
tmp.write(Config_preamble.encode())
|
||||
tmp.write(yaml.encode())
|
||||
|
||||
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
|
||||
tmp.seek(0)
|
||||
new_config.write(tmp.read())
|
||||
|
||||
except Exception as e:
|
||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
||||
if backup is not None:
|
||||
print("restoring previous config file")
|
||||
## workaround, for WinError 183, see above
|
||||
if sys.platform == "win32" and config_file.is_file():
|
||||
config_file.unlink()
|
||||
backup.rename(config_file)
|
||||
return
|
||||
|
||||
print(f"Successfully created new configuration file {config_file}")
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def new_config_file_contents(successfully_downloaded: dict, config_file: Path, opt: dict) -> str:
|
||||
if config_file.exists():
|
||||
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
|
||||
else:
|
||||
conf = OmegaConf.create()
|
||||
|
||||
default_selected = None
|
||||
for model in successfully_downloaded:
|
||||
|
||||
# a bit hacky - what we are doing here is seeing whether a checkpoint
|
||||
# version of the model was previously defined, and whether the current
|
||||
# model is a diffusers (indicated with a path)
|
||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
||||
offer_to_delete_weights(model, conf[model], opt.yes_to_all)
|
||||
|
||||
stanza = {}
|
||||
mod = Datasets[model]
|
||||
stanza["description"] = mod["description"]
|
||||
stanza["repo_id"] = mod["repo_id"]
|
||||
stanza["format"] = mod["format"]
|
||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
||||
# so we no longer require these in INITIAL_MODELS.yaml
|
||||
if "width" in mod:
|
||||
stanza["width"] = mod["width"]
|
||||
if "height" in mod:
|
||||
stanza["height"] = mod["height"]
|
||||
if "file" in mod:
|
||||
stanza["weights"] = os.path.relpath(
|
||||
successfully_downloaded[model], start=Globals.root
|
||||
)
|
||||
stanza["config"] = os.path.normpath(os.path.join(SD_Configs, mod["config"]))
|
||||
if "vae" in mod:
|
||||
if "file" in mod["vae"]:
|
||||
stanza["vae"] = os.path.normpath(
|
||||
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
|
||||
)
|
||||
else:
|
||||
stanza["vae"] = mod["vae"]
|
||||
if mod.get("default", False):
|
||||
stanza["default"] = True
|
||||
default_selected = True
|
||||
|
||||
conf[model] = stanza
|
||||
|
||||
# if no default model was chosen, then we select the first
|
||||
# one in the list
|
||||
if not default_selected:
|
||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
||||
|
||||
return OmegaConf.to_yaml(conf)
|
||||
|
||||
# ---------------------------------------------
|
||||
def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
|
||||
if not (weights := conf_stanza.get('weights')):
|
||||
return
|
||||
if re.match('/VAE/',conf_stanza.get('config')):
|
||||
return
|
||||
if yes_to_all or \
|
||||
yes_or_no(f'\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?', default_yes=False):
|
||||
weights = Path(weights)
|
||||
if not weights.is_absolute():
|
||||
weights = Path(Globals.root) / weights
|
||||
try:
|
||||
weights.unlink()
|
||||
except OSError as e:
|
||||
print(str(e))
|
||||
|
||||
# ---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
@ -652,22 +270,6 @@ def download_bert():
|
||||
print("...success", file=sys.stderr)
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_from_hf(
|
||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
||||
):
|
||||
print("", file=sys.stderr) # to prevent tqdm from overwriting
|
||||
path = global_cache_dir(cache_subdir)
|
||||
model = model_class.from_pretrained(
|
||||
model_name,
|
||||
cache_dir=path,
|
||||
resume_download=True,
|
||||
**kwargs,
|
||||
)
|
||||
model_name = '--'.join(('models',*model_name.split('/')))
|
||||
return path / model_name if model else None
|
||||
|
||||
|
||||
# ---------------------------------------------
|
||||
def download_clip():
|
||||
print("Installing CLIP model (ignore deprecation errors)...", file=sys.stderr)
|
||||
@ -744,8 +346,9 @@ def download_clipseg():
|
||||
def download_safety_checker():
|
||||
print("Installing model for NSFW content detection...", file=sys.stderr)
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
||||
StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from transformers import AutoFeatureExtractor
|
||||
except ModuleNotFoundError:
|
||||
print("Error installing NSFW checker model:")
|
||||
@ -759,52 +362,6 @@ def download_safety_checker():
|
||||
print("...success", file=sys.stderr)
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def download_weights(opt: dict) -> Union[str, None]:
|
||||
precision = (
|
||||
"float32"
|
||||
if opt.full_precision
|
||||
else choose_precision(torch.device(choose_torch_device()))
|
||||
)
|
||||
|
||||
if opt.yes_to_all:
|
||||
models = default_dataset() if opt.default_only else recommended_datasets()
|
||||
access_token = authenticate(opt.yes_to_all)
|
||||
if len(models) > 0:
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models, access_token, precision=precision
|
||||
)
|
||||
update_config_file(successfully_downloaded, opt)
|
||||
return
|
||||
|
||||
else:
|
||||
choice = user_wants_to_download_weights()
|
||||
|
||||
if choice == "recommended":
|
||||
models = recommended_datasets()
|
||||
elif choice == "all":
|
||||
models = all_datasets()
|
||||
elif choice == "customized":
|
||||
models = select_datasets(choice)
|
||||
if models is None and yes_or_no("Quit?", default_yes=False):
|
||||
sys.exit(0)
|
||||
else: # 'skip'
|
||||
return
|
||||
|
||||
access_token = authenticate()
|
||||
if access_token is not None:
|
||||
HfFolder.save_token(access_token)
|
||||
|
||||
print("\n** DOWNLOADING WEIGHTS **")
|
||||
successfully_downloaded = download_weight_datasets(
|
||||
models, access_token, precision=precision
|
||||
)
|
||||
|
||||
update_config_file(successfully_downloaded, opt)
|
||||
if len(successfully_downloaded) < len(models):
|
||||
return "some of the model weights downloads were not successful"
|
||||
|
||||
|
||||
# -------------------------------------
|
||||
def get_root(root: str = None) -> str:
|
||||
if root:
|
||||
@ -951,13 +508,6 @@ class ProgressBar:
|
||||
# -------------------------------------
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||
parser.add_argument(
|
||||
"--interactive",
|
||||
dest="interactive",
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help="run in interactive mode (default) - DEPRECATED",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-sd-weights",
|
||||
dest="skip_sd_weights",
|
||||
@ -1005,26 +555,16 @@ def main():
|
||||
# setting a global here
|
||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||
|
||||
errors = set()
|
||||
|
||||
try:
|
||||
# We check for to see if the runtime directory is correctly initialized.
|
||||
if Globals.root == "" or not os.path.exists(
|
||||
os.path.join(Globals.root, "invokeai.init")
|
||||
):
|
||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||
save_hf_token(opt.yes_to_all)
|
||||
|
||||
# Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
|
||||
errors = set()
|
||||
|
||||
if not opt.interactive:
|
||||
print(
|
||||
"WARNING: The --(no)-interactive argument is deprecated and will be removed. Use --skip-sd-weights."
|
||||
)
|
||||
opt.skip_sd_weights = True
|
||||
if opt.skip_sd_weights:
|
||||
print("** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||
else:
|
||||
print("** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||
errors.add(download_weights(opt))
|
||||
print("\n** DOWNLOADING SUPPORT MODELS **")
|
||||
download_bert()
|
||||
download_clip()
|
||||
@ -1033,6 +573,13 @@ def main():
|
||||
download_codeformer()
|
||||
download_clipseg()
|
||||
download_safety_checker()
|
||||
|
||||
if opt.skip_sd_weights:
|
||||
print("** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||
else:
|
||||
print("** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||
errors.add(select_and_download_models(opt))
|
||||
|
||||
postscript(errors=errors)
|
||||
except KeyboardInterrupt:
|
||||
print("\nGoodbye! Come back soon.")
|
||||
|
@ -25,19 +25,20 @@ import torch
|
||||
import transformers
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers import logging as dlogging
|
||||
from diffusers.utils.logging import (get_verbosity, set_verbosity,
|
||||
set_verbosity_error)
|
||||
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||
from huggingface_hub import scan_cache_dir
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from ldm.invoke.generator.diffusers_pipeline import \
|
||||
StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
||||
global_models_dir)
|
||||
from ldm.util import (ask_user, download_with_progress_bar,
|
||||
instantiate_from_config)
|
||||
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||
from ldm.invoke.globals import (
|
||||
Globals,
|
||||
global_autoscan_dir,
|
||||
global_cache_dir,
|
||||
global_models_dir,
|
||||
)
|
||||
from ldm.util import ask_user, download_with_progress_bar, instantiate_from_config
|
||||
|
||||
DEFAULT_MAX_MODELS = 2
|
||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||
@ -374,8 +375,9 @@ class ModelManager(object):
|
||||
print(
|
||||
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
|
||||
)
|
||||
from ldm.invoke.ckpt_to_diffuser import \
|
||||
load_pipeline_from_original_stable_diffusion_ckpt
|
||||
from ldm.invoke.ckpt_to_diffuser import (
|
||||
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||
)
|
||||
|
||||
if vae_config := self._choose_diffusers_vae(model_name):
|
||||
vae = self._load_vae(vae_config)
|
||||
@ -495,8 +497,8 @@ class ModelManager(object):
|
||||
safety_checker=None, local_files_only=not Globals.internet_available
|
||||
)
|
||||
if "vae" in mconfig and mconfig["vae"] is not None:
|
||||
vae = self._load_vae(mconfig["vae"])
|
||||
pipeline_args.update(vae=vae)
|
||||
if vae := self._load_vae(mconfig["vae"]):
|
||||
pipeline_args.update(vae=vae)
|
||||
if not isinstance(name_or_path, Path):
|
||||
pipeline_args.update(cache_dir=global_cache_dir("diffusers"))
|
||||
if using_fp16:
|
||||
@ -551,7 +553,7 @@ class ModelManager(object):
|
||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||
)
|
||||
|
||||
if "path" in mconfig:
|
||||
if "path" in mconfig and mconfig["path"] is not None:
|
||||
path = Path(mconfig["path"])
|
||||
if not path.is_absolute():
|
||||
path = Path(Globals.root, path).resolve()
|
||||
@ -762,7 +764,7 @@ class ModelManager(object):
|
||||
model_description = model_description or "Optimized version of {model_name}"
|
||||
print(f">> Optimizing {model_name} (30-60s)")
|
||||
try:
|
||||
# By passing the specified VAE too the conversion function, the autoencoder
|
||||
# By passing the specified VAE to the conversion function, the autoencoder
|
||||
# will be built into the model rather than tacked on afterward via the config file
|
||||
vae_model = self._load_vae(vae) if vae else None
|
||||
convert_ckpt_to_diffuser(
|
||||
@ -789,7 +791,9 @@ class ModelManager(object):
|
||||
print(">> Conversion succeeded")
|
||||
except Exception as e:
|
||||
print(f"** Conversion failed: {str(e)}")
|
||||
print("** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)")
|
||||
print(
|
||||
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||
)
|
||||
|
||||
return new_config
|
||||
|
||||
@ -1102,7 +1106,12 @@ class ModelManager(object):
|
||||
|
||||
def _load_vae(self, vae_config) -> AutoencoderKL:
|
||||
vae_args = {}
|
||||
name_or_path = self.model_name_or_path(vae_config)
|
||||
try:
|
||||
name_or_path = self.model_name_or_path(vae_config)
|
||||
except Exception:
|
||||
return None
|
||||
if name_or_path is None:
|
||||
return None
|
||||
using_fp16 = self.precision == "float16"
|
||||
|
||||
vae_args.update(
|
||||
|
@ -108,6 +108,7 @@ dependencies = [
|
||||
"invokeai-configure" = "ldm.invoke.config.invokeai_configure:main"
|
||||
"invokeai-merge" = "ldm.invoke.merge_diffusers:main" # note name munging
|
||||
"invokeai-ti" = "ldm.invoke.training.textual_inversion:main"
|
||||
"invokeai-initial-models" = "ldm.invoke.config.initial_model_select:main"
|
||||
|
||||
[project.urls]
|
||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||
|
Loading…
Reference in New Issue
Block a user