(config) try to authenticate to Huggingface more eagerly, using env vars

This commit is contained in:
Eugene Brodsky 2022-11-27 05:23:42 -05:00 committed by Lincoln Stein
parent f237744ab1
commit ed6194351c

View File

@ -212,8 +212,7 @@ This involves a few easy steps.
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
if access_token is not None: if access_token is not None:
print('found') print('found')
else:
if access_token is None:
print('not found') print('not found')
print(''' print('''
4. Thank you! The last step is to enter your HuggingFace access token so that 4. Thank you! The last step is to enter your HuggingFace access token so that
@ -231,6 +230,7 @@ This involves a few easy steps.
Token: ''' Token: '''
) )
access_token = getpass_asterisk.getpass_asterisk() access_token = getpass_asterisk.getpass_asterisk()
HfFolder.save_token(access_token)
return access_token return access_token
#--------------------------------------------- #---------------------------------------------
@ -521,10 +521,16 @@ def download_safety_checker():
print('...success',file=sys.stderr) print('...success',file=sys.stderr)
#------------------------------------- #-------------------------------------
def download_weights(opt:dict): # Authenticate to Huggingface using environment variables.
# If successful, authentication will persist for either interactive or non-interactive use.
# Default env var expected by HuggingFace is HUGGING_FACE_HUB_TOKEN.
if not (access_token := HfFolder.get_token()):
# If unable to find an existing token or expected environment, try the non-canonical environment variable (widely used in the community and supported as per docs)
if (access_token := os.getenv("HUGGINGFACE_TOKEN")):
HfFolder.save_token(access_token)
if opt.yes_to_all: if opt.yes_to_all:
models = recommended_datasets() models = recommended_datasets()
access_token = HfFolder.get_token()
if len(models)>0 and access_token is not None: if len(models)>0 and access_token is not None:
successfully_downloaded = download_weight_datasets(models, access_token) successfully_downloaded = download_weight_datasets(models, access_token)
update_config_file(successfully_downloaded,opt) update_config_file(successfully_downloaded,opt)
@ -547,6 +553,7 @@ def download_weights(opt:dict):
return return
print('** LICENSE AGREEMENT FOR WEIGHT FILES **') print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
# We are either already authenticated, or will be asked to provide the token interactively
access_token = authenticate() access_token = authenticate()
print('\n** DOWNLOADING WEIGHTS **') print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token) successfully_downloaded = download_weight_datasets(models, access_token)