Further improvements to the preload_models script

- User can choose to download just the recommended models, customize the list of
  models to download, or skip downloading altogether.
- Downloads directly to the models directory instead of to the HuggingFace cache
- Can resume interrupted downloads
This commit is contained in:
Lincoln Stein 2022-10-30 00:17:05 -04:00
parent b532e6dd17
commit 19b6c671a6

View File

@ -13,10 +13,13 @@ import transformers
import os import os
import warnings import warnings
import torch import torch
import urllib.request
import zipfile import zipfile
import traceback import traceback
import getpass import getpass
import requests
from urllib import request
from tqdm import tqdm
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path from pathlib import Path
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
@ -119,35 +122,55 @@ def yes_or_no(prompt:str, default_yes=True):
return response[0] in ('y','Y') return response[0] in ('y','Y')
#--------------------------------------------- #---------------------------------------------
def user_wants_to_download_weights(): def user_wants_to_download_weights()->str:
'''
Returns one of "skip", "recommended" or "customized"
'''
print('''You can download and configure the weights files manually or let this print('''You can download and configure the weights files manually or let this
script do it for you. Manual installation is described at: script do it for you. Manual installation is described at:
https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md
You may download the recommended models (about 10GB total), select a customized set, or
completely skip this step.
''' '''
) )
return yes_or_no('Would you like to download the Stable Diffusion model weights now?') selection = None
while selection is None:
choice = input('Download <r>ecommended models, <c>ustomize the list, or <s>kip this step? [r]: ')
if choice.startswith(('r','R')) or len(choice)==0:
selection = 'recommended'
elif choice.startswith(('c','C')):
selection = 'customized'
elif choice.startswith(('s','S')):
selection = 'skip'
return selection
#--------------------------------------------- #---------------------------------------------
def select_datasets(): def select_datasets(action:str):
done = False done = False
while not done: while not done:
print(''' datasets = dict()
dflt = None # the first model selected will be the default; TODO let user change
counter = 1
if action == 'customized':
print('''
Choose the weight file(s) you wish to download. Before downloading you Choose the weight file(s) you wish to download. Before downloading you
will be given the option to view and change your selections. will be given the option to view and change your selections.
''' '''
) )
datasets = dict() for ds in Datasets.keys():
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
counter = 1 print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
dflt = None # the first model selected will be the default; TODO let user change if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
for ds in Datasets.keys(): datasets[ds]=counter
recommended = '(recommended)' if Datasets[ds]['recommended'] else '' counter += 1
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}') else:
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']): for ds in Datasets.keys():
datasets[ds]=counter if Datasets[ds]['recommended']:
counter += 1 datasets[ds]=counter
counter += 1
print('The following weight files will be downloaded:') print('The following weight files will be downloaded:')
for ds in datasets: for ds in datasets:
@ -157,6 +180,7 @@ will be given the option to view and change your selections.
ok_to_download = yes_or_no('Ok to download?') ok_to_download = yes_or_no('Ok to download?')
if not ok_to_download: if not ok_to_download:
if yes_or_no('Change your selection?'): if yes_or_no('Change your selection?'):
action = 'customized'
pass pass
else: else:
done = True done = True
@ -164,6 +188,7 @@ will be given the option to view and change your selections.
done = True done = True
return datasets if ok_to_download else None return datasets if ok_to_download else None
#-------------------------------Authenticate against Hugging Face #-------------------------------Authenticate against Hugging Face
def authenticate(): def authenticate():
print(''' print('''
@ -180,13 +205,19 @@ This involves a few easy steps.
You will need to verify your email address as part of the HuggingFace You will need to verify your email address as part of the HuggingFace
registration process. registration process.
2. Log into your account Hugging Face: 2. Log into your Hugging Face account:
https://huggingface.co/login https://huggingface.co/login
3. Accept the license terms located here: 3. Accept the license terms located here:
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original https://huggingface.co/runwayml/stable-diffusion-v1-5
and here:
https://huggingface.co/runwayml/stable-diffusion-inpainting
(Yes, you have to accept two slightly different license agreements)
''' '''
) )
input('Press <enter> when you are ready to continue:') input('Press <enter> when you are ready to continue:')
@ -229,7 +260,7 @@ def download_weight_datasets(models:dict, access_token:str):
for mod in models.keys(): for mod in models.keys():
repo_id = Datasets[mod]['repo_id'] repo_id = Datasets[mod]['repo_id']
filename = Datasets[mod]['file'] filename = Datasets[mod]['file']
success = conditional_download( success = download_with_resume(
repo_id=repo_id, repo_id=repo_id,
model_name=filename, model_name=filename,
access_token=access_token access_token=access_token
@ -241,19 +272,50 @@ def download_weight_datasets(models:dict, access_token:str):
return successful return successful
#--------------------------------------------- #---------------------------------------------
def conditional_download(repo_id:str, model_name:str, access_token:str): def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
model_dest = os.path.join(Model_dir, model_name) model_dest = os.path.join(Model_dir, model_name)
if os.path.exists(model_dest):
print(f' * {model_name}: exists')
return True
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
url = hf_hub_url(repo_id, model_name)
header = {"Authorization": f'Bearer {access_token}'}
open_mode = 'wb'
exist_size = 0
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header['Range'] = f'bytes={exist_size}-'
open_mode = 'ab'
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get('content-length', 0))
if resp.status_code==416: # "range not satisfiable", which means nothing to return
print(f'* {model_name}: complete file found. Skipping.')
return True
elif exist_size > 0:
print(f'* {model_name}: partial file found. Resuming...')
else:
print(f'* {model_name}: Downloading...')
try: try:
print(f' * {model_name}: downloading or retrieving from cache...') if total < 2000:
path = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token)) print(f'* {model_name}: {resp.text}')
path.resolve(strict=True).link_to(model_dest) return False
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total+exist_size,
unit='iB',
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e: except Exception as e:
print(f'** Error downloading {model_name}: {str(e)} **') print(f'An error occurred while downloading {model_name}: {str(e)}')
return False return False
return True return True
@ -441,8 +503,9 @@ if __name__ == '__main__':
try: try:
introduction() introduction()
print('** WEIGHT SELECTION **') print('** WEIGHT SELECTION **')
if user_wants_to_download_weights(): choice = user_wants_to_download_weights()
models = select_datasets() if choice != 'skip':
models = select_datasets(choice)
if models is None: if models is None:
if yes_or_no('Quit?',default_yes=False): if yes_or_no('Quit?',default_yes=False):
sys.exit(0) sys.exit(0)