further improvements to preload_models script

- User can choose to download just the recommended models, customize the list of models
  to download, or skip downloading altogether.
- Downloads directly to the models directory instead of to the HuggingFace cache
- Can resume interrupted downloads
This commit is contained in:
Lincoln Stein 2022-10-30 00:17:05 -04:00
parent b532e6dd17
commit 19b6c671a6

View File

@ -13,10 +13,13 @@ import transformers
import os
import warnings
import torch
import urllib.request
import zipfile
import traceback
import getpass
import requests
from urllib import request
from tqdm import tqdm
from omegaconf import OmegaConf
from pathlib import Path
from transformers import CLIPTokenizer, CLIPTextModel
@ -119,36 +122,56 @@ def yes_or_no(prompt:str, default_yes=True):
return response[0] in ('y','Y')
#---------------------------------------------
def user_wants_to_download_weights():
def user_wants_to_download_weights()->str:
'''
Returns one of "skip", "recommended" or "customized"
'''
print('''You can download and configure the weights files manually or let this
script do it for you. Manual installation is described at:
https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md
You may download the recommended models (about 10GB total), select a customized set, or
completely skip this step.
'''
)
return yes_or_no('Would you like to download the Stable Diffusion model weights now?')
selection = None
while selection is None:
choice = input('Download <r>ecommended models, <c>ustomize the list, or <s>kip this step? [r]: ')
if choice.startswith(('r','R')) or len(choice)==0:
selection = 'recommended'
elif choice.startswith(('c','C')):
selection = 'customized'
elif choice.startswith(('s','S')):
selection = 'skip'
return selection
#---------------------------------------------
def select_datasets():
def select_datasets(action:str):
done = False
while not done:
print('''
datasets = dict()
dflt = None # the first model selected will be the default; TODO let user change
counter = 1
if action == 'customized':
print('''
Choose the weight file(s) you wish to download. Before downloading you
will be given the option to view and change your selections.
'''
)
datasets = dict()
counter = 1
dflt = None # the first model selected will be the default; TODO let user change
for ds in Datasets.keys():
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
datasets[ds]=counter
counter += 1
for ds in Datasets.keys():
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
datasets[ds]=counter
counter += 1
else:
for ds in Datasets.keys():
if Datasets[ds]['recommended']:
datasets[ds]=counter
counter += 1
print('The following weight files will be downloaded:')
for ds in datasets:
dflt = '*' if dflt is None else ''
@ -157,13 +180,15 @@ will be given the option to view and change your selections.
ok_to_download = yes_or_no('Ok to download?')
if not ok_to_download:
if yes_or_no('Change your selection?'):
action = 'customized'
pass
else:
done = True
else:
done = True
return datasets if ok_to_download else None
#-------------------------------Authenticate against Hugging Face
def authenticate():
print('''
@ -180,13 +205,19 @@ This involves a few easy steps.
You will need to verify your email address as part of the HuggingFace
registration process.
2. Log into your account Hugging Face:
2. Log into your Hugging Face account:
https://huggingface.co/login
3. Accept the license terms located here:
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
https://huggingface.co/runwayml/stable-diffusion-v1-5
and here:
https://huggingface.co/runwayml/stable-diffusion-inpainting
(Yes, you have to accept two slightly different license agreements)
'''
)
input('Press <enter> when you are ready to continue:')
@ -229,7 +260,7 @@ def download_weight_datasets(models:dict, access_token:str):
for mod in models.keys():
repo_id = Datasets[mod]['repo_id']
filename = Datasets[mod]['file']
success = conditional_download(
success = download_with_resume(
repo_id=repo_id,
model_name=filename,
access_token=access_token
@ -241,19 +272,50 @@ def download_weight_datasets(models:dict, access_token:str):
return successful
#---------------------------------------------
def conditional_download(repo_id:str, model_name:str, access_token:str):
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
model_dest = os.path.join(Model_dir, model_name)
if os.path.exists(model_dest):
print(f' * {model_name}: exists')
return True
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
url = hf_hub_url(repo_id, model_name)
header = {"Authorization": f'Bearer {access_token}'}
open_mode = 'wb'
exist_size = 0
if os.path.exists(model_dest):
exist_size = os.path.getsize(model_dest)
header['Range'] = f'bytes={exist_size}-'
open_mode = 'ab'
resp = requests.get(url, headers=header, stream=True)
total = int(resp.headers.get('content-length', 0))
if resp.status_code==416: # "range not satisfiable", which means nothing to return
print(f'* {model_name}: complete file found. Skipping.')
return True
elif exist_size > 0:
print(f'* {model_name}: partial file found. Resuming...')
else:
print(f'* {model_name}: Downloading...')
try:
print(f' * {model_name}: downloading or retrieving from cache...')
path = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token))
path.resolve(strict=True).link_to(model_dest)
if total < 2000:
print(f'* {model_name}: {resp.text}')
return False
with open(model_dest, open_mode) as file, tqdm(
desc=model_name,
initial=exist_size,
total=total+exist_size,
unit='iB',
unit_scale=True,
unit_divisor=1000,
) as bar:
for data in resp.iter_content(chunk_size=1024):
size = file.write(data)
bar.update(size)
except Exception as e:
print(f'** Error downloading {model_name}: {str(e)} **')
print(f'An error occurred while downloading {model_name}: {str(e)}')
return False
return True
@ -435,14 +497,15 @@ def download_safety_checker():
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
print('...success')
#-------------------------------------
if __name__ == '__main__':
try:
introduction()
print('** WEIGHT SELECTION **')
if user_wants_to_download_weights():
models = select_datasets()
choice = user_wants_to_download_weights()
if choice != 'skip':
models = select_datasets(choice)
if models is None:
if yes_or_no('Quit?',default_yes=False):
sys.exit(0)