mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add new console frontend to initial model selection, and other improvements
1. The invokeai-configure script has now been refactored. The work of selecting and downloading initial models at install time is now done by a script named invokeai-initial-models (module name is ldm.invoke.config.initial_model_select) The calling arguments for invokeai-configure have not changed, so nothing should break. After initializing the root directory, the script calls invokeai-initial-models to let the user select the starting models to install. 2. invokeai-initial-models puts up a console GUI with checkboxes to indicate which models to install. It respects the --default_only and --yes arguments so that CI will continue to work. 3. User can now edit the VAE assigned to diffusers models in the CLI. 4. Fixed a bug that caused a crash during model loading when the VAE is set to None, rather than being empty.
This commit is contained in:
parent
3dd7393984
commit
714fff39ba
@ -12,8 +12,9 @@ echo 2. browser-based UI
|
|||||||
echo 3. run textual inversion training
|
echo 3. run textual inversion training
|
||||||
echo 4. merge models (diffusers type only)
|
echo 4. merge models (diffusers type only)
|
||||||
echo 5. re-run the configure script to download new models
|
echo 5. re-run the configure script to download new models
|
||||||
echo 6. open the developer console
|
echo 6. download more starter models from HuggingFace
|
||||||
echo 7. command-line help
|
echo 7. open the developer console
|
||||||
|
echo 8. command-line help
|
||||||
set /P restore="Please enter 1, 2, 3, 4, 5, 6 or 7: [2] "
|
set /P restore="Please enter 1, 2, 3, 4, 5, 6 or 7: [2] "
|
||||||
if not defined restore set restore=2
|
if not defined restore set restore=2
|
||||||
IF /I "%restore%" == "1" (
|
IF /I "%restore%" == "1" (
|
||||||
@ -32,6 +33,9 @@ IF /I "%restore%" == "1" (
|
|||||||
echo Running invokeai-configure...
|
echo Running invokeai-configure...
|
||||||
python .venv\Scripts\invokeai-configure.exe %*
|
python .venv\Scripts\invokeai-configure.exe %*
|
||||||
) ELSE IF /I "%restore%" == "6" (
|
) ELSE IF /I "%restore%" == "6" (
|
||||||
|
echo Running invokeai-initial-models...
|
||||||
|
python .venv\Scripts\invokeai-initial-models.exe %*
|
||||||
|
) ELSE IF /I "%restore%" == "7" (
|
||||||
echo Developer Console
|
echo Developer Console
|
||||||
echo Python command is:
|
echo Python command is:
|
||||||
where python
|
where python
|
||||||
@ -43,7 +47,7 @@ IF /I "%restore%" == "1" (
|
|||||||
echo *************************
|
echo *************************
|
||||||
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
echo *** Type `exit` to quit this shell and deactivate the Python virtual environment ***
|
||||||
call cmd /k
|
call cmd /k
|
||||||
) ELSE IF /I "%restore%" == "7" (
|
) ELSE IF /I "%restore%" == "8" (
|
||||||
echo Displaying command line help...
|
echo Displaying command line help...
|
||||||
python .venv\Scripts\invokeai.exe --help %*
|
python .venv\Scripts\invokeai.exe --help %*
|
||||||
pause
|
pause
|
||||||
|
@ -30,11 +30,12 @@ if [ "$0" != "bash" ]; then
|
|||||||
echo "2. browser-based UI"
|
echo "2. browser-based UI"
|
||||||
echo "3. run textual inversion training"
|
echo "3. run textual inversion training"
|
||||||
echo "4. merge models (diffusers type only)"
|
echo "4. merge models (diffusers type only)"
|
||||||
echo "5. open the developer console"
|
echo "5. re-run the configure script to fix a broken install"
|
||||||
echo "6. re-run the configure script to download new models"
|
echo "6. download more starter models from HuggingFace"
|
||||||
echo "7. command-line help "
|
echo "7. open the developer console"
|
||||||
|
echo "8. command-line help "
|
||||||
echo ""
|
echo ""
|
||||||
read -p "Please enter 1, 2, 3, 4, 5, 6 or 7: [2] " yn
|
read -p "Please enter 1, 2, 3, 4, 5, 6, 7 or 8: [2] " yn
|
||||||
choice=${yn:='2'}
|
choice=${yn:='2'}
|
||||||
case $choice in
|
case $choice in
|
||||||
1)
|
1)
|
||||||
@ -54,14 +55,17 @@ if [ "$0" != "bash" ]; then
|
|||||||
exec invokeai-merge --gui $@
|
exec invokeai-merge --gui $@
|
||||||
;;
|
;;
|
||||||
5)
|
5)
|
||||||
|
exec invokeai-configure --root ${INVOKEAI_ROOT}
|
||||||
|
;;
|
||||||
|
6)
|
||||||
|
exec invokeai-initial-models --root ${INVOKEAI_ROOT}
|
||||||
|
;;
|
||||||
|
7)
|
||||||
echo "Developer Console:"
|
echo "Developer Console:"
|
||||||
file_name=$(basename "${BASH_SOURCE[0]}")
|
file_name=$(basename "${BASH_SOURCE[0]}")
|
||||||
bash --init-file "$file_name"
|
bash --init-file "$file_name"
|
||||||
;;
|
;;
|
||||||
6)
|
8)
|
||||||
exec invokeai-configure --root ${INVOKEAI_ROOT}
|
|
||||||
;;
|
|
||||||
7)
|
|
||||||
exec invokeai --help
|
exec invokeai --help
|
||||||
;;
|
;;
|
||||||
*)
|
*)
|
||||||
|
@ -68,7 +68,7 @@ trinart-characters-2_0:
|
|||||||
width: 512
|
width: 512
|
||||||
height: 512
|
height: 512
|
||||||
recommended: False
|
recommended: False
|
||||||
ft-mse-improved-autoencoder-840000:
|
autoencoder-840000:
|
||||||
description: StabilityAI improved autoencoder fine-tuned for human faces. Improves legacy .ckpt models (335 MB)
|
description: StabilityAI improved autoencoder fine-tuned for human faces. Improves legacy .ckpt models (335 MB)
|
||||||
repo_id: stabilityai/sd-vae-ft-mse-original
|
repo_id: stabilityai/sd-vae-ft-mse-original
|
||||||
format: ckpt
|
format: ckpt
|
||||||
|
@ -820,6 +820,18 @@ def edit_model(model_name:str, gen, opt, completer):
|
|||||||
completer.set_line(info[attribute])
|
completer.set_line(info[attribute])
|
||||||
info[attribute] = input(f'{attribute}: ') or info[attribute]
|
info[attribute] = input(f'{attribute}: ') or info[attribute]
|
||||||
|
|
||||||
|
if info['format'] == 'diffusers':
|
||||||
|
vae = info.get('vae',dict(repo_id=None,path=None,subfolder=None))
|
||||||
|
completer.set_line(vae.get('repo_id') or 'stabilityai/sd-vae-ft-mse')
|
||||||
|
vae['repo_id'] = input('External VAE repo_id: ').strip() or None
|
||||||
|
if not vae['repo_id']:
|
||||||
|
completer.set_line(vae.get('path') or '')
|
||||||
|
vae['path'] = input('Path to a local diffusers VAE model (usually none): ').strip() or None
|
||||||
|
completer.set_line(vae.get('subfolder') or '')
|
||||||
|
vae['subfolder'] = input('Name of subfolder containing the VAE model (usually none): ').strip() or None
|
||||||
|
info['vae'] = vae
|
||||||
|
|
||||||
|
|
||||||
if new_name != model_name:
|
if new_name != model_name:
|
||||||
manager.del_model(model_name)
|
manager.del_model(model_name)
|
||||||
|
|
||||||
|
@ -10,37 +10,36 @@ print("Loading Python libraries...\n")
|
|||||||
import argparse
|
import argparse
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from tempfile import TemporaryFile
|
|
||||||
from typing import Union
|
|
||||||
from urllib import request
|
from urllib import request
|
||||||
|
|
||||||
import requests
|
|
||||||
import transformers
|
import transformers
|
||||||
from diffusers import AutoencoderKL
|
|
||||||
from getpass_asterisk import getpass_asterisk
|
from getpass_asterisk import getpass_asterisk
|
||||||
from huggingface_hub import HfFolder, hf_hub_url
|
from huggingface_hub import HfFolder
|
||||||
from huggingface_hub import login as hf_hub_login
|
from huggingface_hub import login as hf_hub_login
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from transformers import (AutoProcessor, CLIPSegForImageSegmentation,
|
from transformers import (
|
||||||
CLIPTextModel, CLIPTokenizer)
|
AutoProcessor,
|
||||||
|
CLIPSegForImageSegmentation,
|
||||||
|
CLIPTextModel,
|
||||||
|
CLIPTokenizer,
|
||||||
|
)
|
||||||
|
|
||||||
import invokeai.configs as configs
|
import invokeai.configs as configs
|
||||||
from ldm.invoke.devices import choose_precision, choose_torch_device
|
from ldm.invoke.config.initial_model_select import (
|
||||||
from ldm.invoke.generator.diffusers_pipeline import \
|
download_from_hf,
|
||||||
StableDiffusionGeneratorPipeline
|
select_and_download_models,
|
||||||
from ldm.invoke.globals import Globals, global_cache_dir, global_config_dir
|
yes_or_no,
|
||||||
|
)
|
||||||
|
from ldm.invoke.globals import Globals, global_config_dir
|
||||||
from ldm.invoke.readline import generic_completer
|
from ldm.invoke.readline import generic_completer
|
||||||
|
|
||||||
warnings.filterwarnings("ignore")
|
warnings.filterwarnings("ignore")
|
||||||
import torch
|
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
@ -104,125 +103,6 @@ Have fun!
|
|||||||
|
|
||||||
print(message)
|
print(message)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def yes_or_no(prompt: str, default_yes=True):
|
|
||||||
completer.set_options(["yes", "no"])
|
|
||||||
completer.complete_extensions(None) # turn off path-completion mode
|
|
||||||
default = "y" if default_yes else "n"
|
|
||||||
response = input(f"{prompt} [{default}] ") or default
|
|
||||||
if default_yes:
|
|
||||||
return response[0] not in ("n", "N")
|
|
||||||
else:
|
|
||||||
return response[0] in ("y", "Y")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def user_wants_to_download_weights() -> str:
|
|
||||||
"""
|
|
||||||
Returns one of "skip", "recommended" or "customized"
|
|
||||||
"""
|
|
||||||
print(
|
|
||||||
"""You can download and configure the weights files manually or let this
|
|
||||||
script do it for you. Manual installation is described at:
|
|
||||||
|
|
||||||
https://invoke-ai.github.io/InvokeAI/installation/020_INSTALL_MANUAL/
|
|
||||||
|
|
||||||
You may download the recommended models (about 15GB total), install all models (40 GB!!)
|
|
||||||
select a customized set, or completely skip this step.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
completer.set_options(["recommended", "customized", "skip"])
|
|
||||||
completer.complete_extensions(None) # turn off path-completion mode
|
|
||||||
selection = None
|
|
||||||
while selection is None:
|
|
||||||
choice = input(
|
|
||||||
"Download <r>ecommended models, <a>ll models, <c>ustomized list, or <s>kip this step? [r]: "
|
|
||||||
)
|
|
||||||
if choice.startswith(("r", "R")) or len(choice) == 0:
|
|
||||||
selection = "recommended"
|
|
||||||
elif choice.startswith(("c", "C")):
|
|
||||||
selection = "customized"
|
|
||||||
elif choice.startswith(("a", "A")):
|
|
||||||
selection = "all"
|
|
||||||
elif choice.startswith(("s", "S")):
|
|
||||||
selection = "skip"
|
|
||||||
return selection
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def select_datasets(action: str):
|
|
||||||
done = False
|
|
||||||
default_datasets = default_dataset()
|
|
||||||
while not done:
|
|
||||||
datasets = dict()
|
|
||||||
counter = 1
|
|
||||||
|
|
||||||
if action == "customized":
|
|
||||||
print(
|
|
||||||
"""
|
|
||||||
Choose the weight file(s) you wish to download. Before downloading you
|
|
||||||
will be given the option to view and change your selections.
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
for ds in Datasets.keys():
|
|
||||||
recommended = Datasets[ds].get("recommended", False)
|
|
||||||
r_str = "(recommended)" if recommended else ""
|
|
||||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {r_str}')
|
|
||||||
if yes_or_no(" Download?", default_yes=recommended):
|
|
||||||
datasets[ds] = True
|
|
||||||
counter += 1
|
|
||||||
else:
|
|
||||||
for ds in Datasets.keys():
|
|
||||||
if Datasets[ds].get("recommended", False):
|
|
||||||
datasets[ds] = True
|
|
||||||
counter += 1
|
|
||||||
|
|
||||||
print("The following weight files will be downloaded:")
|
|
||||||
counter = 1
|
|
||||||
for ds in datasets:
|
|
||||||
dflt = "*" if ds in default_datasets else ""
|
|
||||||
print(f" [{counter}] {ds}{dflt}")
|
|
||||||
counter += 1
|
|
||||||
print("* default")
|
|
||||||
ok_to_download = yes_or_no("Ok to download?")
|
|
||||||
if not ok_to_download:
|
|
||||||
if yes_or_no("Change your selection?"):
|
|
||||||
action = "customized"
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
done = True
|
|
||||||
else:
|
|
||||||
done = True
|
|
||||||
return datasets if ok_to_download else None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def recommended_datasets() -> dict:
|
|
||||||
datasets = dict()
|
|
||||||
for ds in Datasets.keys():
|
|
||||||
if Datasets[ds].get("recommended", False):
|
|
||||||
datasets[ds] = True
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def default_dataset() -> dict:
|
|
||||||
datasets = dict()
|
|
||||||
for ds in Datasets.keys():
|
|
||||||
if Datasets[ds].get("default", False):
|
|
||||||
datasets[ds] = True
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def all_datasets() -> dict:
|
|
||||||
datasets = dict()
|
|
||||||
for ds in Datasets.keys():
|
|
||||||
datasets[ds] = True
|
|
||||||
return datasets
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def HfLogin(access_token) -> str:
|
def HfLogin(access_token) -> str:
|
||||||
"""
|
"""
|
||||||
@ -242,7 +122,7 @@ def HfLogin(access_token) -> str:
|
|||||||
|
|
||||||
|
|
||||||
# -------------------------------Authenticate against Hugging Face
|
# -------------------------------Authenticate against Hugging Face
|
||||||
def authenticate(yes_to_all=False):
|
def save_hf_token(yes_to_all=False):
|
||||||
print("** LICENSE AGREEMENT FOR WEIGHT FILES **")
|
print("** LICENSE AGREEMENT FOR WEIGHT FILES **")
|
||||||
print("=" * shutil.get_terminal_size()[0])
|
print("=" * shutil.get_terminal_size()[0])
|
||||||
print(
|
print(
|
||||||
@ -356,149 +236,6 @@ You may re-run the configuration script again in the future if you do not wish t
|
|||||||
return access_token
|
return access_token
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
# look for legacy model.ckpt in models directory and offer to
|
|
||||||
# normalize its name
|
|
||||||
def migrate_models_ckpt():
|
|
||||||
model_path = os.path.join(Globals.root, Model_dir, Weights_dir)
|
|
||||||
if not os.path.exists(os.path.join(model_path, "model.ckpt")):
|
|
||||||
return
|
|
||||||
new_name = Datasets["stable-diffusion-1.4"]["file"]
|
|
||||||
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
|
|
||||||
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
|
|
||||||
if rename:
|
|
||||||
print(f"model.ckpt => {new_name}")
|
|
||||||
os.replace(
|
|
||||||
os.path.join(model_path, "model.ckpt"), os.path.join(model_path, new_name)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_weight_datasets(
|
|
||||||
models: dict, access_token: str, precision: str = "float32"
|
|
||||||
):
|
|
||||||
migrate_models_ckpt()
|
|
||||||
successful = dict()
|
|
||||||
for mod in models.keys():
|
|
||||||
print(f"Downloading {mod}:")
|
|
||||||
successful[mod] = _download_repo_or_file(
|
|
||||||
Datasets[mod], access_token, precision=precision
|
|
||||||
)
|
|
||||||
return successful
|
|
||||||
|
|
||||||
|
|
||||||
def _download_repo_or_file(
|
|
||||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
|
||||||
) -> Path:
|
|
||||||
path = None
|
|
||||||
if mconfig["format"] == "ckpt":
|
|
||||||
path = _download_ckpt_weights(mconfig, access_token)
|
|
||||||
else:
|
|
||||||
path = _download_diffusion_weights(mconfig, access_token, precision=precision)
|
|
||||||
if "vae" in mconfig and "repo_id" in mconfig["vae"]:
|
|
||||||
_download_diffusion_weights(
|
|
||||||
mconfig["vae"], access_token, precision=precision
|
|
||||||
)
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
def _download_ckpt_weights(mconfig: DictConfig, access_token: str) -> Path:
|
|
||||||
repo_id = mconfig["repo_id"]
|
|
||||||
filename = mconfig["file"]
|
|
||||||
cache_dir = os.path.join(Globals.root, Model_dir, Weights_dir)
|
|
||||||
return hf_download_with_resume(
|
|
||||||
repo_id=repo_id,
|
|
||||||
model_dir=cache_dir,
|
|
||||||
model_name=filename,
|
|
||||||
access_token=access_token,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _download_diffusion_weights(
|
|
||||||
mconfig: DictConfig, access_token: str, precision: str = "float32"
|
|
||||||
):
|
|
||||||
repo_id = mconfig["repo_id"]
|
|
||||||
model_class = (
|
|
||||||
StableDiffusionGeneratorPipeline
|
|
||||||
if mconfig.get("format", None) == "diffusers"
|
|
||||||
else AutoencoderKL
|
|
||||||
)
|
|
||||||
extra_arg_list = [{"revision": "fp16"}, {}] if precision == "float16" else [{}]
|
|
||||||
path = None
|
|
||||||
for extra_args in extra_arg_list:
|
|
||||||
try:
|
|
||||||
path = download_from_hf(
|
|
||||||
model_class,
|
|
||||||
repo_id,
|
|
||||||
cache_subdir="diffusers",
|
|
||||||
safety_checker=None,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
except OSError as e:
|
|
||||||
if str(e).startswith("fp16 is not a valid"):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
print(f"An unexpected error occurred while downloading the model: {e})")
|
|
||||||
if path:
|
|
||||||
break
|
|
||||||
return path
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def hf_download_with_resume(
|
|
||||||
repo_id: str, model_dir: str, model_name: str, access_token: str = None
|
|
||||||
) -> Path:
|
|
||||||
model_dest = Path(os.path.join(model_dir, model_name))
|
|
||||||
os.makedirs(model_dir, exist_ok=True)
|
|
||||||
|
|
||||||
url = hf_hub_url(repo_id, model_name)
|
|
||||||
|
|
||||||
header = {"Authorization": f"Bearer {access_token}"} if access_token else {}
|
|
||||||
open_mode = "wb"
|
|
||||||
exist_size = 0
|
|
||||||
|
|
||||||
if os.path.exists(model_dest):
|
|
||||||
exist_size = os.path.getsize(model_dest)
|
|
||||||
header["Range"] = f"bytes={exist_size}-"
|
|
||||||
open_mode = "ab"
|
|
||||||
|
|
||||||
resp = requests.get(url, headers=header, stream=True)
|
|
||||||
total = int(resp.headers.get("content-length", 0))
|
|
||||||
|
|
||||||
if (
|
|
||||||
resp.status_code == 416
|
|
||||||
): # "range not satisfiable", which means nothing to return
|
|
||||||
print(f"* {model_name}: complete file found. Skipping.")
|
|
||||||
return model_dest
|
|
||||||
elif resp.status_code != 200:
|
|
||||||
print(f"** An error occurred during downloading {model_name}: {resp.reason}")
|
|
||||||
elif exist_size > 0:
|
|
||||||
print(f"* {model_name}: partial file found. Resuming...")
|
|
||||||
else:
|
|
||||||
print(f"* {model_name}: Downloading...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
if total < 2000:
|
|
||||||
print(f"*** ERROR DOWNLOADING {model_name}: {resp.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
with open(model_dest, open_mode) as file, tqdm(
|
|
||||||
desc=model_name,
|
|
||||||
initial=exist_size,
|
|
||||||
total=total + exist_size,
|
|
||||||
unit="iB",
|
|
||||||
unit_scale=True,
|
|
||||||
unit_divisor=1000,
|
|
||||||
) as bar:
|
|
||||||
for data in resp.iter_content(chunk_size=1024):
|
|
||||||
size = file.write(data)
|
|
||||||
bar.update(size)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"An error occurred while downloading {model_name}: {str(e)}")
|
|
||||||
return None
|
|
||||||
return model_dest
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
def download_with_progress_bar(model_url: str, model_dest: str, label: str = "the"):
|
||||||
try:
|
try:
|
||||||
@ -517,125 +254,6 @@ def download_with_progress_bar(model_url: str, model_dest: str, label: str = "th
|
|||||||
print(f"Error downloading {label} model")
|
print(f"Error downloading {label} model")
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def update_config_file(successfully_downloaded: dict, opt: dict):
|
|
||||||
config_file = (
|
|
||||||
Path(opt.config_file) if opt.config_file is not None else Default_config_file
|
|
||||||
)
|
|
||||||
|
|
||||||
# In some cases (incomplete setup, etc), the default configs directory might be missing.
|
|
||||||
# Create it if it doesn't exist.
|
|
||||||
# this check is ignored if opt.config_file is specified - user is assumed to know what they
|
|
||||||
# are doing if they are passing a custom config file from elsewhere.
|
|
||||||
if config_file is Default_config_file and not config_file.parent.exists():
|
|
||||||
configs_src = Dataset_path.parent
|
|
||||||
configs_dest = Default_config_file.parent
|
|
||||||
shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
|
|
||||||
|
|
||||||
yaml = new_config_file_contents(successfully_downloaded, config_file, opt)
|
|
||||||
|
|
||||||
try:
|
|
||||||
backup = None
|
|
||||||
if os.path.exists(config_file):
|
|
||||||
print(
|
|
||||||
f"** {config_file.name} exists. Renaming to {config_file.stem}.yaml.orig"
|
|
||||||
)
|
|
||||||
backup = config_file.with_suffix(".yaml.orig")
|
|
||||||
## Ugh. Windows is unable to overwrite an existing backup file, raises a WinError 183
|
|
||||||
if sys.platform == "win32" and backup.is_file():
|
|
||||||
backup.unlink()
|
|
||||||
config_file.rename(backup)
|
|
||||||
|
|
||||||
with TemporaryFile() as tmp:
|
|
||||||
tmp.write(Config_preamble.encode())
|
|
||||||
tmp.write(yaml.encode())
|
|
||||||
|
|
||||||
with open(str(config_file.expanduser().resolve()), "wb") as new_config:
|
|
||||||
tmp.seek(0)
|
|
||||||
new_config.write(tmp.read())
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"**Error creating config file {config_file}: {str(e)} **")
|
|
||||||
if backup is not None:
|
|
||||||
print("restoring previous config file")
|
|
||||||
## workaround, for WinError 183, see above
|
|
||||||
if sys.platform == "win32" and config_file.is_file():
|
|
||||||
config_file.unlink()
|
|
||||||
backup.rename(config_file)
|
|
||||||
return
|
|
||||||
|
|
||||||
print(f"Successfully created new configuration file {config_file}")
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def new_config_file_contents(successfully_downloaded: dict, config_file: Path, opt: dict) -> str:
|
|
||||||
if config_file.exists():
|
|
||||||
conf = OmegaConf.load(str(config_file.expanduser().resolve()))
|
|
||||||
else:
|
|
||||||
conf = OmegaConf.create()
|
|
||||||
|
|
||||||
default_selected = None
|
|
||||||
for model in successfully_downloaded:
|
|
||||||
|
|
||||||
# a bit hacky - what we are doing here is seeing whether a checkpoint
|
|
||||||
# version of the model was previously defined, and whether the current
|
|
||||||
# model is a diffusers (indicated with a path)
|
|
||||||
if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
|
|
||||||
offer_to_delete_weights(model, conf[model], opt.yes_to_all)
|
|
||||||
|
|
||||||
stanza = {}
|
|
||||||
mod = Datasets[model]
|
|
||||||
stanza["description"] = mod["description"]
|
|
||||||
stanza["repo_id"] = mod["repo_id"]
|
|
||||||
stanza["format"] = mod["format"]
|
|
||||||
# diffusers don't need width and height (probably .ckpt doesn't either)
|
|
||||||
# so we no longer require these in INITIAL_MODELS.yaml
|
|
||||||
if "width" in mod:
|
|
||||||
stanza["width"] = mod["width"]
|
|
||||||
if "height" in mod:
|
|
||||||
stanza["height"] = mod["height"]
|
|
||||||
if "file" in mod:
|
|
||||||
stanza["weights"] = os.path.relpath(
|
|
||||||
successfully_downloaded[model], start=Globals.root
|
|
||||||
)
|
|
||||||
stanza["config"] = os.path.normpath(os.path.join(SD_Configs, mod["config"]))
|
|
||||||
if "vae" in mod:
|
|
||||||
if "file" in mod["vae"]:
|
|
||||||
stanza["vae"] = os.path.normpath(
|
|
||||||
os.path.join(Model_dir, Weights_dir, mod["vae"]["file"])
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
stanza["vae"] = mod["vae"]
|
|
||||||
if mod.get("default", False):
|
|
||||||
stanza["default"] = True
|
|
||||||
default_selected = True
|
|
||||||
|
|
||||||
conf[model] = stanza
|
|
||||||
|
|
||||||
# if no default model was chosen, then we select the first
|
|
||||||
# one in the list
|
|
||||||
if not default_selected:
|
|
||||||
conf[list(successfully_downloaded.keys())[0]]["default"] = True
|
|
||||||
|
|
||||||
return OmegaConf.to_yaml(conf)
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
|
|
||||||
if not (weights := conf_stanza.get('weights')):
|
|
||||||
return
|
|
||||||
if re.match('/VAE/',conf_stanza.get('config')):
|
|
||||||
return
|
|
||||||
if yes_to_all or \
|
|
||||||
yes_or_no(f'\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?', default_yes=False):
|
|
||||||
weights = Path(weights)
|
|
||||||
if not weights.is_absolute():
|
|
||||||
weights = Path(Globals.root) / weights
|
|
||||||
try:
|
|
||||||
weights.unlink()
|
|
||||||
except OSError as e:
|
|
||||||
print(str(e))
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
# this will preload the Bert tokenizer fles
|
# this will preload the Bert tokenizer fles
|
||||||
def download_bert():
|
def download_bert():
|
||||||
@ -652,22 +270,6 @@ def download_bert():
|
|||||||
print("...success", file=sys.stderr)
|
print("...success", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
|
||||||
def download_from_hf(
|
|
||||||
model_class: object, model_name: str, cache_subdir: Path = Path("hub"), **kwargs
|
|
||||||
):
|
|
||||||
print("", file=sys.stderr) # to prevent tqdm from overwriting
|
|
||||||
path = global_cache_dir(cache_subdir)
|
|
||||||
model = model_class.from_pretrained(
|
|
||||||
model_name,
|
|
||||||
cache_dir=path,
|
|
||||||
resume_download=True,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
model_name = '--'.join(('models',*model_name.split('/')))
|
|
||||||
return path / model_name if model else None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_clip():
|
def download_clip():
|
||||||
print("Installing CLIP model (ignore deprecation errors)...", file=sys.stderr)
|
print("Installing CLIP model (ignore deprecation errors)...", file=sys.stderr)
|
||||||
@ -744,8 +346,9 @@ def download_clipseg():
|
|||||||
def download_safety_checker():
|
def download_safety_checker():
|
||||||
print("Installing model for NSFW content detection...", file=sys.stderr)
|
print("Installing model for NSFW content detection...", file=sys.stderr)
|
||||||
try:
|
try:
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import \
|
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||||
StableDiffusionSafetyChecker
|
StableDiffusionSafetyChecker,
|
||||||
|
)
|
||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
print("Error installing NSFW checker model:")
|
print("Error installing NSFW checker model:")
|
||||||
@ -759,52 +362,6 @@ def download_safety_checker():
|
|||||||
print("...success", file=sys.stderr)
|
print("...success", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
|
||||||
def download_weights(opt: dict) -> Union[str, None]:
|
|
||||||
precision = (
|
|
||||||
"float32"
|
|
||||||
if opt.full_precision
|
|
||||||
else choose_precision(torch.device(choose_torch_device()))
|
|
||||||
)
|
|
||||||
|
|
||||||
if opt.yes_to_all:
|
|
||||||
models = default_dataset() if opt.default_only else recommended_datasets()
|
|
||||||
access_token = authenticate(opt.yes_to_all)
|
|
||||||
if len(models) > 0:
|
|
||||||
successfully_downloaded = download_weight_datasets(
|
|
||||||
models, access_token, precision=precision
|
|
||||||
)
|
|
||||||
update_config_file(successfully_downloaded, opt)
|
|
||||||
return
|
|
||||||
|
|
||||||
else:
|
|
||||||
choice = user_wants_to_download_weights()
|
|
||||||
|
|
||||||
if choice == "recommended":
|
|
||||||
models = recommended_datasets()
|
|
||||||
elif choice == "all":
|
|
||||||
models = all_datasets()
|
|
||||||
elif choice == "customized":
|
|
||||||
models = select_datasets(choice)
|
|
||||||
if models is None and yes_or_no("Quit?", default_yes=False):
|
|
||||||
sys.exit(0)
|
|
||||||
else: # 'skip'
|
|
||||||
return
|
|
||||||
|
|
||||||
access_token = authenticate()
|
|
||||||
if access_token is not None:
|
|
||||||
HfFolder.save_token(access_token)
|
|
||||||
|
|
||||||
print("\n** DOWNLOADING WEIGHTS **")
|
|
||||||
successfully_downloaded = download_weight_datasets(
|
|
||||||
models, access_token, precision=precision
|
|
||||||
)
|
|
||||||
|
|
||||||
update_config_file(successfully_downloaded, opt)
|
|
||||||
if len(successfully_downloaded) < len(models):
|
|
||||||
return "some of the model weights downloads were not successful"
|
|
||||||
|
|
||||||
|
|
||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def get_root(root: str = None) -> str:
|
def get_root(root: str = None) -> str:
|
||||||
if root:
|
if root:
|
||||||
@ -951,13 +508,6 @@ class ProgressBar:
|
|||||||
# -------------------------------------
|
# -------------------------------------
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
parser = argparse.ArgumentParser(description="InvokeAI model downloader")
|
||||||
parser.add_argument(
|
|
||||||
"--interactive",
|
|
||||||
dest="interactive",
|
|
||||||
action=argparse.BooleanOptionalAction,
|
|
||||||
default=True,
|
|
||||||
help="run in interactive mode (default) - DEPRECATED",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--skip-sd-weights",
|
"--skip-sd-weights",
|
||||||
dest="skip_sd_weights",
|
dest="skip_sd_weights",
|
||||||
@ -1005,26 +555,16 @@ def main():
|
|||||||
# setting a global here
|
# setting a global here
|
||||||
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
Globals.root = os.path.expanduser(get_root(opt.root) or "")
|
||||||
|
|
||||||
|
errors = set()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# We check for to see if the runtime directory is correctly initialized.
|
# We check for to see if the runtime directory is correctly initialized.
|
||||||
if Globals.root == "" or not os.path.exists(
|
if Globals.root == "" or not os.path.exists(
|
||||||
os.path.join(Globals.root, "invokeai.init")
|
os.path.join(Globals.root, "invokeai.init")
|
||||||
):
|
):
|
||||||
initialize_rootdir(Globals.root, opt.yes_to_all)
|
initialize_rootdir(Globals.root, opt.yes_to_all)
|
||||||
|
save_hf_token(opt.yes_to_all)
|
||||||
|
|
||||||
# Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
|
|
||||||
errors = set()
|
|
||||||
|
|
||||||
if not opt.interactive:
|
|
||||||
print(
|
|
||||||
"WARNING: The --(no)-interactive argument is deprecated and will be removed. Use --skip-sd-weights."
|
|
||||||
)
|
|
||||||
opt.skip_sd_weights = True
|
|
||||||
if opt.skip_sd_weights:
|
|
||||||
print("** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
|
||||||
else:
|
|
||||||
print("** DOWNLOADING DIFFUSION WEIGHTS **")
|
|
||||||
errors.add(download_weights(opt))
|
|
||||||
print("\n** DOWNLOADING SUPPORT MODELS **")
|
print("\n** DOWNLOADING SUPPORT MODELS **")
|
||||||
download_bert()
|
download_bert()
|
||||||
download_clip()
|
download_clip()
|
||||||
@ -1033,6 +573,13 @@ def main():
|
|||||||
download_codeformer()
|
download_codeformer()
|
||||||
download_clipseg()
|
download_clipseg()
|
||||||
download_safety_checker()
|
download_safety_checker()
|
||||||
|
|
||||||
|
if opt.skip_sd_weights:
|
||||||
|
print("** SKIPPING DIFFUSION WEIGHTS DOWNLOAD PER USER REQUEST **")
|
||||||
|
else:
|
||||||
|
print("** DOWNLOADING DIFFUSION WEIGHTS **")
|
||||||
|
errors.add(select_and_download_models(opt))
|
||||||
|
|
||||||
postscript(errors=errors)
|
postscript(errors=errors)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
print("\nGoodbye! Come back soon.")
|
print("\nGoodbye! Come back soon.")
|
||||||
|
@ -25,19 +25,20 @@ import torch
|
|||||||
import transformers
|
import transformers
|
||||||
from diffusers import AutoencoderKL
|
from diffusers import AutoencoderKL
|
||||||
from diffusers import logging as dlogging
|
from diffusers import logging as dlogging
|
||||||
from diffusers.utils.logging import (get_verbosity, set_verbosity,
|
from diffusers.utils.logging import get_verbosity, set_verbosity, set_verbosity_error
|
||||||
set_verbosity_error)
|
|
||||||
from huggingface_hub import scan_cache_dir
|
from huggingface_hub import scan_cache_dir
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from omegaconf.dictconfig import DictConfig
|
from omegaconf.dictconfig import DictConfig
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from ldm.invoke.generator.diffusers_pipeline import \
|
from ldm.invoke.generator.diffusers_pipeline import StableDiffusionGeneratorPipeline
|
||||||
StableDiffusionGeneratorPipeline
|
from ldm.invoke.globals import (
|
||||||
from ldm.invoke.globals import (Globals, global_autoscan_dir, global_cache_dir,
|
Globals,
|
||||||
global_models_dir)
|
global_autoscan_dir,
|
||||||
from ldm.util import (ask_user, download_with_progress_bar,
|
global_cache_dir,
|
||||||
instantiate_from_config)
|
global_models_dir,
|
||||||
|
)
|
||||||
|
from ldm.util import ask_user, download_with_progress_bar, instantiate_from_config
|
||||||
|
|
||||||
DEFAULT_MAX_MODELS = 2
|
DEFAULT_MAX_MODELS = 2
|
||||||
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
VAE_TO_REPO_ID = { # hack, see note in convert_and_import()
|
||||||
@ -374,8 +375,9 @@ class ModelManager(object):
|
|||||||
print(
|
print(
|
||||||
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
|
f">> Converting legacy checkpoint {model_name} into a diffusers model..."
|
||||||
)
|
)
|
||||||
from ldm.invoke.ckpt_to_diffuser import \
|
from ldm.invoke.ckpt_to_diffuser import (
|
||||||
load_pipeline_from_original_stable_diffusion_ckpt
|
load_pipeline_from_original_stable_diffusion_ckpt,
|
||||||
|
)
|
||||||
|
|
||||||
if vae_config := self._choose_diffusers_vae(model_name):
|
if vae_config := self._choose_diffusers_vae(model_name):
|
||||||
vae = self._load_vae(vae_config)
|
vae = self._load_vae(vae_config)
|
||||||
@ -495,7 +497,7 @@ class ModelManager(object):
|
|||||||
safety_checker=None, local_files_only=not Globals.internet_available
|
safety_checker=None, local_files_only=not Globals.internet_available
|
||||||
)
|
)
|
||||||
if "vae" in mconfig and mconfig["vae"] is not None:
|
if "vae" in mconfig and mconfig["vae"] is not None:
|
||||||
vae = self._load_vae(mconfig["vae"])
|
if vae := self._load_vae(mconfig["vae"]):
|
||||||
pipeline_args.update(vae=vae)
|
pipeline_args.update(vae=vae)
|
||||||
if not isinstance(name_or_path, Path):
|
if not isinstance(name_or_path, Path):
|
||||||
pipeline_args.update(cache_dir=global_cache_dir("diffusers"))
|
pipeline_args.update(cache_dir=global_cache_dir("diffusers"))
|
||||||
@ -551,7 +553,7 @@ class ModelManager(object):
|
|||||||
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
f'"{model_name}" is not a known model name. Please check your models.yaml file'
|
||||||
)
|
)
|
||||||
|
|
||||||
if "path" in mconfig:
|
if "path" in mconfig and mconfig["path"] is not None:
|
||||||
path = Path(mconfig["path"])
|
path = Path(mconfig["path"])
|
||||||
if not path.is_absolute():
|
if not path.is_absolute():
|
||||||
path = Path(Globals.root, path).resolve()
|
path = Path(Globals.root, path).resolve()
|
||||||
@ -762,7 +764,7 @@ class ModelManager(object):
|
|||||||
model_description = model_description or "Optimized version of {model_name}"
|
model_description = model_description or "Optimized version of {model_name}"
|
||||||
print(f">> Optimizing {model_name} (30-60s)")
|
print(f">> Optimizing {model_name} (30-60s)")
|
||||||
try:
|
try:
|
||||||
# By passing the specified VAE too the conversion function, the autoencoder
|
# By passing the specified VAE to the conversion function, the autoencoder
|
||||||
# will be built into the model rather than tacked on afterward via the config file
|
# will be built into the model rather than tacked on afterward via the config file
|
||||||
vae_model = self._load_vae(vae) if vae else None
|
vae_model = self._load_vae(vae) if vae else None
|
||||||
convert_ckpt_to_diffuser(
|
convert_ckpt_to_diffuser(
|
||||||
@ -789,7 +791,9 @@ class ModelManager(object):
|
|||||||
print(">> Conversion succeeded")
|
print(">> Conversion succeeded")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"** Conversion failed: {str(e)}")
|
print(f"** Conversion failed: {str(e)}")
|
||||||
print("** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)")
|
print(
|
||||||
|
"** If you are trying to convert an inpainting or 2.X model, please indicate the correct config file (e.g. v1-inpainting-inference.yaml)"
|
||||||
|
)
|
||||||
|
|
||||||
return new_config
|
return new_config
|
||||||
|
|
||||||
@ -1102,7 +1106,12 @@ class ModelManager(object):
|
|||||||
|
|
||||||
def _load_vae(self, vae_config) -> AutoencoderKL:
|
def _load_vae(self, vae_config) -> AutoencoderKL:
|
||||||
vae_args = {}
|
vae_args = {}
|
||||||
|
try:
|
||||||
name_or_path = self.model_name_or_path(vae_config)
|
name_or_path = self.model_name_or_path(vae_config)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
if name_or_path is None:
|
||||||
|
return None
|
||||||
using_fp16 = self.precision == "float16"
|
using_fp16 = self.precision == "float16"
|
||||||
|
|
||||||
vae_args.update(
|
vae_args.update(
|
||||||
|
@ -108,6 +108,7 @@ dependencies = [
|
|||||||
"invokeai-configure" = "ldm.invoke.config.invokeai_configure:main"
|
"invokeai-configure" = "ldm.invoke.config.invokeai_configure:main"
|
||||||
"invokeai-merge" = "ldm.invoke.merge_diffusers:main" # note name munging
|
"invokeai-merge" = "ldm.invoke.merge_diffusers:main" # note name munging
|
||||||
"invokeai-ti" = "ldm.invoke.training.textual_inversion:main"
|
"invokeai-ti" = "ldm.invoke.training.textual_inversion:main"
|
||||||
|
"invokeai-initial-models" = "ldm.invoke.config.initial_model_select:main"
|
||||||
|
|
||||||
[project.urls]
|
[project.urls]
|
||||||
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
"Homepage" = "https://invoke-ai.github.io/InvokeAI/"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user