Bring main up to 2.3.0-rc6 (#2563)

This bumps up the version number and applies a hotfix to the
configure script to fix the problem described in PR #2562.
Lincoln Stein, 2023-02-07 18:02:13 -05:00 (committed by GitHub)
commit 9c2b9af3a8
2 changed files with 34 additions and 10 deletions

File 1 of 2 (version string):

@@ -1 +1 @@
-__version__='2.3.0-rc5'
+__version__='2.3.0-rc6'

File 2 of 2 (the configure script):

@@ -10,6 +10,7 @@ print("Loading Python libraries...\n")
 import argparse
 import io
 import os
+import re
 import shutil
 import sys
 import traceback
@@ -320,7 +321,7 @@ You may re-run the configuration script again in the future if you do not wish t
     while again:
         try:
             access_token = getpass_asterisk.getpass_asterisk(prompt="HF Token ")
-            if access_token is None or len(access_token)==0:
+            if access_token is None or len(access_token) == 0:
                 raise EOFError
             HfLogin(access_token)
             access_token = HfFolder.get_token()
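For reference, the pattern in the hunk above: an empty token is converted into an EOFError so the enclosing try/except can loop and prompt again. A minimal standalone sketch of the same shape, with plain input() standing in for getpass_asterisk and the HfLogin validation omitted:

# Sketch only: the real script masks input with getpass_asterisk and then
# validates the token via HfLogin; here we simply re-prompt until the
# user enters a non-empty value.
def prompt_hf_token() -> str:
    while True:
        token = input("HF Token ").strip()
        if token:
            return token
        print("Token may not be empty - please try again.")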
@@ -379,7 +380,7 @@ def download_weight_datasets(
     migrate_models_ckpt()
     successful = dict()
     for mod in models.keys():
-        print(f"{mod}...", file=sys.stderr, end="")
+        print(f"Downloading {mod}:")
         successful[mod] = _download_repo_or_file(
             Datasets[mod], access_token, precision=precision
         )
@@ -532,7 +533,7 @@ def update_config_file(successfully_downloaded: dict, opt: dict):
     configs_dest = Default_config_file.parent
     shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)
-    yaml = new_config_file_contents(successfully_downloaded, config_file)
+    yaml = new_config_file_contents(successfully_downloaded, config_file, opt)
     try:
         backup = None
@@ -568,7 +569,7 @@ def update_config_file(successfully_downloaded: dict, opt: dict):
 # ---------------------------------------------
-def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
+def new_config_file_contents(successfully_downloaded: dict, config_file: Path, opt: dict) -> str:
     if config_file.exists():
         conf = OmegaConf.load(str(config_file.expanduser().resolve()))
     else:
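For context, opt is the parsed command-line namespace that update_config_file already receives; the new parameter simply threads it down so that new_config_file_contents can read opt.yes_to_all (used in the next hunk). A minimal sketch of where that attribute would come from; the exact flag spelling is an assumption, not taken from this diff:

import argparse

# Hypothetical flag wiring: the only attribute the new code reads is
# yes_to_all, presumably set by a --yes/-y style option.
parser = argparse.ArgumentParser()
parser.add_argument("--yes", "-y", dest="yes_to_all", action="store_true")
opt = parser.parse_args(["--yes"])
print(opt.yes_to_all)  # True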
@@ -576,7 +577,14 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
     default_selected = None
     for model in successfully_downloaded:
-        stanza = conf[model] if model in conf else {}
+        # a bit hacky - what we are doing here is seeing whether a checkpoint
+        # version of the model was previously defined, and whether the current
+        # model is a diffusers (indicated with a path)
+        if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
+            offer_to_delete_weights(model, conf[model], opt.yes_to_all)
+        stanza = {}
         mod = Datasets[model]
         stanza["description"] = mod["description"]
         stanza["repo_id"] = mod["repo_id"]
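The distinction the new check relies on: a legacy checkpoint download is a single .ckpt file, while a diffusers download is a directory tree, so Path.is_dir() is enough to tell the two apart. A self-contained sketch of that test (the file and directory names are hypothetical):

import tempfile
from pathlib import Path

# Stand-ins for the two download layouts: a diffusers repo is a
# directory, a legacy checkpoint is a single file.
root = Path(tempfile.mkdtemp())
(root / "sd-1.5-diffusers").mkdir()
(root / "legacy.ckpt").touch()

for item in ("sd-1.5-diffusers", "legacy.ckpt"):
    kind = "diffusers" if (root / item).is_dir() else "checkpoint"
    print(f"{item}: {kind}")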
@@ -599,8 +607,8 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
             )
         else:
             stanza["vae"] = mod["vae"]
-        if mod.get('default',False):
-            stanza['default'] = True
+        if mod.get("default", False):
+            stanza["default"] = True
             default_selected = True
         conf[model] = stanza
@@ -612,7 +620,22 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
     return OmegaConf.to_yaml(conf)

+# ---------------------------------------------
+def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
+    if not (weights := conf_stanza.get('weights')):
+        return
+    if re.match('/VAE/', conf_stanza.get('config')):
+        return
+    if yes_to_all or \
+       yes_or_no(f'\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?', default_yes=False):
+        weights = Path(weights)
+        if not weights.is_absolute():
+            weights = Path(Globals.root) / weights
+        try:
+            weights.unlink()
+        except OSError as e:
+            print(str(e))

 # ---------------------------------------------
 # this will preload the Bert tokenizer files
 def download_bert():
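One detail worth noting in offer_to_delete_weights: the weights path stored in the config may be relative, in which case it is resolved against Globals.root before the file is unlinked. A standalone sketch of just that resolution step, with a hypothetical root_dir standing in for Globals.root:

from pathlib import Path

def resolve_weights(weights: str, root_dir: Path) -> Path:
    # Relative config paths are taken to be relative to the application
    # root; absolute paths are used unchanged.
    path = Path(weights)
    return path if path.is_absolute() else root_dir / path

print(resolve_weights("models/ldm/sd-v1-5.ckpt", Path("/home/user/invokeai")))
# -> /home/user/invokeai/models/ldm/sd-v1-5.ckpt
print(resolve_weights("/tmp/model.ckpt", Path("/home/user/invokeai")))
# -> /tmp/model.ckpt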
@@ -641,7 +664,8 @@ def download_from_hf(
         resume_download=True,
         **kwargs,
     )
-    return path if model else None
+    model_name = '--'.join(('models', *model_name.split('/')))
+    return path / model_name if model else None

 # ---------------------------------------------
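The joined name in the new return value mirrors the huggingface_hub cache layout, in which a repo id of the form owner/name is stored under a directory named models--owner--name. A quick illustration of the string transformation (the repo id is an example, not taken from this diff):

# "owner/name" -> "models--owner--name", matching the directory naming
# used by the huggingface_hub download cache.
model_name = "runwayml/stable-diffusion-v1-5"
cache_dir_name = "--".join(("models", *model_name.split("/")))
print(cache_dir_name)  # models--runwayml--stable-diffusion-v1-5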