(config) make user aware of any problems downloading models

also implement a generic way of reporting issues at the end of installation
This commit is contained in:
Eugene Brodsky 2022-11-27 05:26:38 -05:00 committed by Lincoln Stein
parent ed6194351c
commit 76633f500a

View File

@ -18,6 +18,7 @@ from tqdm import tqdm
from omegaconf import OmegaConf
from huggingface_hub import HfFolder, hf_hub_url
from pathlib import Path
from typing import Union
from getpass_asterisk import getpass_asterisk
from transformers import CLIPTokenizer, CLIPTextModel
from ldm.invoke.globals import Globals
@ -62,9 +63,9 @@ this program and resume later.\n'''
)
#--------------------------------------------
def postscript():
print(
'''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
def postscript(errors: None):
if not any(errors):
message='''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
Web version:
python scripts/invoke.py --web (connect to http://localhost:9090)
Command-line version:
@ -77,7 +78,14 @@ automated installation script, execute "invoke.sh" (Linux/Mac) or
Have fun!
'''
)
else:
message=f"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
for err in errors:
message += f"\t - {err}\n"
message += "Please check the logs above and correct any issues."
print(message)
#---------------------------------------------
def yes_or_no(prompt:str, default_yes=True):
@ -521,6 +529,7 @@ def download_safety_checker():
print('...success',file=sys.stderr)
#-------------------------------------
def download_weights(opt:dict) -> Union[str, None]:
# Authenticate to Huggingface using environment variables.
# If successful, authentication will persist for either interactive or non-interactive use.
# Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN.
@ -537,7 +546,8 @@ def download_safety_checker():
return
else:
print('** Cannot download models because no Hugging Face access token could be found. Please re-run without --yes')
return
return "could not download model weights from Huggingface due to missing or invalid access token"
else:
choice = user_wants_to_download_weights()
@ -558,6 +568,8 @@ def download_safety_checker():
print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token)
update_config_file(successfully_downloaded,opt)
if len(successfully_downloaded) < len(models):
return "some of the model weights downloads were not successful"
#-------------------------------------
def get_root(root:str=None)->str:
@ -746,9 +758,12 @@ def main():
or not os.path.exists(os.path.join(Globals.root,'configs/stable-diffusion/v1-inference.yaml')):
initialize_rootdir(Globals.root,opt.yes_to_all)
# Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
errors=set()
if opt.interactive:
print('** DOWNLOADING DIFFUSION WEIGHTS **')
download_weights(opt)
errors.add(download_weights(opt))
print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert()
download_clip()
@ -757,7 +772,7 @@ def main():
download_codeformer()
download_clipseg()
download_safety_checker()
postscript()
postscript(errors=errors)
except KeyboardInterrupt:
print('\nGoodbye! Come back soon.')
except Exception as e: