Merge branch 'invoke-ai:main' into main

Matthias Wild, 2023-02-08 00:25:49 +01:00 (committed by GitHub)
commit 6b4a06c3fc
3 changed files with 35 additions and 11 deletions

@@ -1 +1 @@
-__version__='2.3.0-rc5'
+__version__='2.3.0-rc6'

@@ -10,6 +10,7 @@ print("Loading Python libraries...\n")
 import argparse
 import io
 import os
+import re
 import shutil
 import sys
 import traceback
@@ -320,7 +321,7 @@ You may re-run the configuration script again in the future if you do not wish t
     while again:
         try:
             access_token = getpass_asterisk.getpass_asterisk(prompt="HF Token ")
-            if access_token is None or len(access_token)==0:
+            if access_token is None or len(access_token) == 0:
                 raise EOFError
             HfLogin(access_token)
             access_token = HfFolder.get_token()
@@ -379,7 +380,7 @@ def download_weight_datasets(
     migrate_models_ckpt()
     successful = dict()
     for mod in models.keys():
-        print(f"{mod}...", file=sys.stderr, end="")
+        print(f"Downloading {mod}:")
         successful[mod] = _download_repo_or_file(
             Datasets[mod], access_token, precision=precision
         )
@@ -532,7 +533,7 @@ def update_config_file(successfully_downloaded: dict, opt: dict):
         configs_dest = Default_config_file.parent
         shutil.copytree(configs_src, configs_dest, dirs_exist_ok=True)

-    yaml = new_config_file_contents(successfully_downloaded, config_file)
+    yaml = new_config_file_contents(successfully_downloaded, config_file, opt)

     try:
         backup = None
@@ -568,7 +569,7 @@ def update_config_file(successfully_downloaded: dict, opt: dict):
 # ---------------------------------------------
-def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -> str:
+def new_config_file_contents(successfully_downloaded: dict, config_file: Path, opt: dict) -> str:
     if config_file.exists():
         conf = OmegaConf.load(str(config_file.expanduser().resolve()))
     else:
@@ -576,7 +577,14 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -
     default_selected = None
     for model in successfully_downloaded:
         stanza = conf[model] if model in conf else {}
+
+        # a bit hacky - what we are doing here is seeing whether a checkpoint
+        # version of the model was previously defined, and whether the current
+        # model is a diffusers (indicated with a path)
+        if conf.get(model) and Path(successfully_downloaded[model]).is_dir():
+            offer_to_delete_weights(model, conf[model], opt.yes_to_all)
+            stanza = {}
         mod = Datasets[model]
         stanza["description"] = mod["description"]
         stanza["repo_id"] = mod["repo_id"]
@@ -599,8 +607,8 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -
             )
         else:
             stanza["vae"] = mod["vae"]
-        if mod.get('default',False):
-            stanza['default'] = True
+        if mod.get("default", False):
+            stanza["default"] = True
             default_selected = True
         conf[model] = stanza
@@ -612,7 +620,22 @@ def new_config_file_contents(successfully_downloaded: dict, config_file: Path) -
     return OmegaConf.to_yaml(conf)

+# ---------------------------------------------
+def offer_to_delete_weights(model_name: str, conf_stanza: dict, yes_to_all: bool):
+    if not (weights := conf_stanza.get('weights')):
+        return
+    if re.match('/VAE/',conf_stanza.get('config')):
+        return
+    if yes_to_all or \
+       yes_or_no(f'\n** The checkpoint version of {model_name} is superseded by the diffusers version. Delete the original file {weights}?', default_yes=False):
+        weights = Path(weights)
+        if not weights.is_absolute():
+            weights = Path(Globals.root) / weights
+        try:
+            weights.unlink()
+        except OSError as e:
+            print(str(e))

 # ---------------------------------------------
 # this will preload the Bert tokenizer fles
 def download_bert():
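
One behavioral detail of the new helper worth noting: re.match() anchors at the start of the string, so the guard above only skips stanzas whose config value begins with '/VAE/'; matching the pattern anywhere in the string would require re.search(). A quick illustration (the config path is a hypothetical example):

    import re

    config = "configs/stable-diffusion/VAE/default.yaml"  # hypothetical path
    print(bool(re.match('/VAE/', config)))   # False: match is anchored at position 0
    print(bool(re.search('/VAE/', config)))  # True: search scans the whole string
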
@@ -641,7 +664,8 @@ def download_from_hf(
         resume_download=True,
         **kwargs,
     )
-    return path if model else None
+    model_name = '--'.join(('models',*model_name.split('/')))
+    return path / model_name if model else None

 # ---------------------------------------------
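
The joined name mirrors huggingface_hub's cache layout, where each repo snapshot is stored under a models--{owner}--{name} directory, so the function now returns the model's actual cache folder rather than the cache root. A quick check of the construction, using an example repo id (any "owner/name" id produces the same pattern):

    model_name = "runwayml/stable-diffusion-v1-5"  # example repo id
    folder = '--'.join(('models', *model_name.split('/')))
    print(folder)  # models--runwayml--stable-diffusion-v1-5
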

@@ -317,7 +317,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
             # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
             pass
         else:
-            self.enable_attention_slicing(slice_size='auto')
+            self.enable_attention_slicing(slice_size='max')

     def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                               conditioning_data: ConditioningData,
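
In diffusers, slice_size='auto' splits the attention computation in half, while 'max' runs it one slice at a time, giving the largest memory savings at some speed cost; this change opts into that trade-off. A minimal usage sketch, assuming a stock diffusers pipeline (the model id is an example, not part of this commit):

    import torch
    from diffusers import StableDiffusionPipeline

    # Load an example pipeline; any StableDiffusionPipeline behaves the same.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    pipe.enable_attention_slicing("max")  # one slice at a time: least memory
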