(config) make user aware of any problems downloading models

also implement a generic way of reporting issues at the end of installation
This commit is contained in:
Eugene Brodsky 2022-11-27 05:26:38 -05:00 committed by Lincoln Stein
parent ed6194351c
commit 76633f500a

View File

@ -18,6 +18,7 @@ from tqdm import tqdm
from omegaconf import OmegaConf from omegaconf import OmegaConf
from huggingface_hub import HfFolder, hf_hub_url from huggingface_hub import HfFolder, hf_hub_url
from pathlib import Path from pathlib import Path
from typing import Union
from getpass_asterisk import getpass_asterisk from getpass_asterisk import getpass_asterisk
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
from ldm.invoke.globals import Globals from ldm.invoke.globals import Globals
@ -62,9 +63,9 @@ this program and resume later.\n'''
) )
#-------------------------------------------- #--------------------------------------------
def postscript(): def postscript(errors: None):
print( if not any(errors):
'''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands: message='''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
Web version: Web version:
python scripts/invoke.py --web (connect to http://localhost:9090) python scripts/invoke.py --web (connect to http://localhost:9090)
Command-line version: Command-line version:
@ -77,7 +78,14 @@ automated installation script, execute "invoke.sh" (Linux/Mac) or
Have fun! Have fun!
''' '''
)
else:
message=f"\n** There were errors during installation. It is possible some of the models were not fully downloaded.\n"
for err in errors:
message += f"\t - {err}\n"
message += "Please check the logs above and correct any issues."
print(message)
#--------------------------------------------- #---------------------------------------------
def yes_or_no(prompt:str, default_yes=True): def yes_or_no(prompt:str, default_yes=True):
@ -521,6 +529,7 @@ def download_safety_checker():
print('...success',file=sys.stderr) print('...success',file=sys.stderr)
#------------------------------------- #-------------------------------------
def download_weights(opt:dict) -> Union[str, None]:
# Authenticate to Huggingface using environment variables. # Authenticate to Huggingface using environment variables.
# If successful, authentication will persist for either interactive or non-interactive use. # If successful, authentication will persist for either interactive or non-interactive use.
# Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN. # Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN.
@ -537,7 +546,8 @@ def download_safety_checker():
return return
else: else:
print('** Cannot download models because no Hugging Face access token could be found. Please re-run without --yes') print('** Cannot download models because no Hugging Face access token could be found. Please re-run without --yes')
return return "could not download model weights from Huggingface due to missing or invalid access token"
else: else:
choice = user_wants_to_download_weights() choice = user_wants_to_download_weights()
@ -558,6 +568,8 @@ def download_safety_checker():
print('\n** DOWNLOADING WEIGHTS **') print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token) successfully_downloaded = download_weight_datasets(models, access_token)
update_config_file(successfully_downloaded,opt) update_config_file(successfully_downloaded,opt)
if len(successfully_downloaded) < len(models):
return "some of the model weights downloads were not successful"
#------------------------------------- #-------------------------------------
def get_root(root:str=None)->str: def get_root(root:str=None)->str:
@ -746,9 +758,12 @@ def main():
or not os.path.exists(os.path.join(Globals.root,'configs/stable-diffusion/v1-inference.yaml')): or not os.path.exists(os.path.join(Globals.root,'configs/stable-diffusion/v1-inference.yaml')):
initialize_rootdir(Globals.root,opt.yes_to_all) initialize_rootdir(Globals.root,opt.yes_to_all)
# Optimistically try to download all required assets. If any errors occur, add them and proceed anyway.
errors=set()
if opt.interactive: if opt.interactive:
print('** DOWNLOADING DIFFUSION WEIGHTS **') print('** DOWNLOADING DIFFUSION WEIGHTS **')
download_weights(opt) errors.add(download_weights(opt))
print('\n** DOWNLOADING SUPPORT MODELS **') print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert() download_bert()
download_clip() download_clip()
@ -757,7 +772,7 @@ def main():
download_codeformer() download_codeformer()
download_clipseg() download_clipseg()
download_safety_checker() download_safety_checker()
postscript() postscript(errors=errors)
except KeyboardInterrupt: except KeyboardInterrupt:
print('\nGoodbye! Come back soon.') print('\nGoodbye! Come back soon.')
except Exception as e: except Exception as e: