diff --git a/scripts/preload_models.py b/scripts/preload_models.py index bd737a32a3..9485985230 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -13,10 +13,13 @@ import transformers import os import warnings import torch -import urllib.request import zipfile import traceback import getpass +import requests + +from urllib import request +from tqdm import tqdm from omegaconf import OmegaConf from pathlib import Path from transformers import CLIPTokenizer, CLIPTextModel @@ -119,36 +122,56 @@ def yes_or_no(prompt:str, default_yes=True): return response[0] in ('y','Y') #--------------------------------------------- -def user_wants_to_download_weights(): +def user_wants_to_download_weights()->str: + ''' + Returns one of "skip", "recommended" or "customized" + ''' print('''You can download and configure the weights files manually or let this script do it for you. Manual installation is described at: https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md +You may download the recommended models (about 10GB total), select a customized set, or +completely skip this step. ''' ) - return yes_or_no('Would you like to download the Stable Diffusion model weights now?') + selection = None + while selection is None: + choice = input('Download ecommended models, ustomize the list, or kip this step? [r]: ') + if choice.startswith(('r','R')) or len(choice)==0: + selection = 'recommended' + elif choice.startswith(('c','C')): + selection = 'customized' + elif choice.startswith(('s','S')): + selection = 'skip' + return selection #--------------------------------------------- -def select_datasets(): +def select_datasets(action:str): done = False while not done: - print(''' + datasets = dict() + dflt = None # the first model selected will be the default; TODO let user change + counter = 1 + + if action == 'customized': + print(''' Choose the weight file(s) you wish to download. Before downloading you will be given the option to view and change your selections. ''' ) - datasets = dict() - - counter = 1 - dflt = None # the first model selected will be the default; TODO let user change - for ds in Datasets.keys(): - recommended = '(recommended)' if Datasets[ds]['recommended'] else '' - print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}') - if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']): - datasets[ds]=counter - counter += 1 - + for ds in Datasets.keys(): + recommended = '(recommended)' if Datasets[ds]['recommended'] else '' + print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}') + if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']): + datasets[ds]=counter + counter += 1 + else: + for ds in Datasets.keys(): + if Datasets[ds]['recommended']: + datasets[ds]=counter + counter += 1 + print('The following weight files will be downloaded:') for ds in datasets: dflt = '*' if dflt is None else '' @@ -157,13 +180,15 @@ will be given the option to view and change your selections. ok_to_download = yes_or_no('Ok to download?') if not ok_to_download: if yes_or_no('Change your selection?'): + action = 'customized' pass else: done = True else: done = True return datasets if ok_to_download else None - + + #-------------------------------Authenticate against Hugging Face def authenticate(): print(''' @@ -180,13 +205,19 @@ This involves a few easy steps. You will need to verify your email address as part of the HuggingFace registration process. -2. Log into your account Hugging Face: +2. Log into your Hugging Face account: https://huggingface.co/login 3. Accept the license terms located here: - https://huggingface.co/CompVis/stable-diffusion-v-1-4-original + https://huggingface.co/runwayml/stable-diffusion-v1-5 + + and here: + + https://huggingface.co/runwayml/stable-diffusion-inpainting + + (Yes, you have to accept two slightly different license agreements) ''' ) input('Press when you are ready to continue:') @@ -229,7 +260,7 @@ def download_weight_datasets(models:dict, access_token:str): for mod in models.keys(): repo_id = Datasets[mod]['repo_id'] filename = Datasets[mod]['file'] - success = conditional_download( + success = download_with_resume( repo_id=repo_id, model_name=filename, access_token=access_token @@ -241,19 +272,50 @@ def download_weight_datasets(models:dict, access_token:str): return successful #--------------------------------------------- -def conditional_download(repo_id:str, model_name:str, access_token:str): +def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: + model_dest = os.path.join(Model_dir, model_name) - if os.path.exists(model_dest): - print(f' * {model_name}: exists') - return True os.makedirs(os.path.dirname(model_dest), exist_ok=True) + url = hf_hub_url(repo_id, model_name) + + header = {"Authorization": f'Bearer {access_token}'} + open_mode = 'wb' + exist_size = 0 + + if os.path.exists(model_dest): + exist_size = os.path.getsize(model_dest) + header['Range'] = f'bytes={exist_size}-' + open_mode = 'ab' + + resp = requests.get(url, headers=header, stream=True) + total = int(resp.headers.get('content-length', 0)) + + if resp.status_code==416: # "range not satisfiable", which means nothing to return + print(f'* {model_name}: complete file found. Skipping.') + return True + elif exist_size > 0: + print(f'* {model_name}: partial file found. Resuming...') + else: + print(f'* {model_name}: Downloading...') try: - print(f' * {model_name}: downloading or retrieving from cache...') - path = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token)) - path.resolve(strict=True).link_to(model_dest) + if total < 2000: + print(f'* {model_name}: {resp.text}') + return False + + with open(model_dest, open_mode) as file, tqdm( + desc=model_name, + initial=exist_size, + total=total+exist_size, + unit='iB', + unit_scale=True, + unit_divisor=1000, + ) as bar: + for data in resp.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) except Exception as e: - print(f'** Error downloading {model_name}: {str(e)} **') + print(f'An error occurred while downloading {model_name}: {str(e)}') return False return True @@ -435,14 +497,15 @@ def download_safety_checker(): safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) print('...success') - + #------------------------------------- if __name__ == '__main__': try: introduction() print('** WEIGHT SELECTION **') - if user_wants_to_download_weights(): - models = select_datasets() + choice = user_wants_to_download_weights() + if choice != 'skip': + models = select_datasets(choice) if models is None: if yes_or_no('Quit?',default_yes=False): sys.exit(0)