mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further improvements to preload_models script
- User can choose to download just recommended models, customize list to download, or skip downloading altogether. - Does direct download to models directory instead of to HuggingFace cache - Able to resume interrupted downloads
This commit is contained in:
parent
b532e6dd17
commit
19b6c671a6
@ -13,10 +13,13 @@ import transformers
|
||||
import os
|
||||
import warnings
|
||||
import torch
|
||||
import urllib.request
|
||||
import zipfile
|
||||
import traceback
|
||||
import getpass
|
||||
import requests
|
||||
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
@ -119,36 +122,56 @@ def yes_or_no(prompt:str, default_yes=True):
|
||||
return response[0] in ('y','Y')
|
||||
|
||||
#---------------------------------------------
|
||||
def user_wants_to_download_weights():
|
||||
def user_wants_to_download_weights()->str:
|
||||
'''
|
||||
Returns one of "skip", "recommended" or "customized"
|
||||
'''
|
||||
print('''You can download and configure the weights files manually or let this
|
||||
script do it for you. Manual installation is described at:
|
||||
|
||||
https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md
|
||||
|
||||
You may download the recommended models (about 10GB total), select a customized set, or
|
||||
completely skip this step.
|
||||
'''
|
||||
)
|
||||
return yes_or_no('Would you like to download the Stable Diffusion model weights now?')
|
||||
selection = None
|
||||
while selection is None:
|
||||
choice = input('Download <r>ecommended models, <c>ustomize the list, or <s>kip this step? [r]: ')
|
||||
if choice.startswith(('r','R')) or len(choice)==0:
|
||||
selection = 'recommended'
|
||||
elif choice.startswith(('c','C')):
|
||||
selection = 'customized'
|
||||
elif choice.startswith(('s','S')):
|
||||
selection = 'skip'
|
||||
return selection
|
||||
|
||||
#---------------------------------------------
|
||||
def select_datasets():
|
||||
def select_datasets(action:str):
|
||||
done = False
|
||||
while not done:
|
||||
print('''
|
||||
datasets = dict()
|
||||
dflt = None # the first model selected will be the default; TODO let user change
|
||||
counter = 1
|
||||
|
||||
if action == 'customized':
|
||||
print('''
|
||||
Choose the weight file(s) you wish to download. Before downloading you
|
||||
will be given the option to view and change your selections.
|
||||
'''
|
||||
)
|
||||
datasets = dict()
|
||||
|
||||
counter = 1
|
||||
dflt = None # the first model selected will be the default; TODO let user change
|
||||
for ds in Datasets.keys():
|
||||
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
||||
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
||||
datasets[ds]=counter
|
||||
counter += 1
|
||||
|
||||
for ds in Datasets.keys():
|
||||
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
||||
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
||||
datasets[ds]=counter
|
||||
counter += 1
|
||||
else:
|
||||
for ds in Datasets.keys():
|
||||
if Datasets[ds]['recommended']:
|
||||
datasets[ds]=counter
|
||||
counter += 1
|
||||
|
||||
print('The following weight files will be downloaded:')
|
||||
for ds in datasets:
|
||||
dflt = '*' if dflt is None else ''
|
||||
@ -157,13 +180,15 @@ will be given the option to view and change your selections.
|
||||
ok_to_download = yes_or_no('Ok to download?')
|
||||
if not ok_to_download:
|
||||
if yes_or_no('Change your selection?'):
|
||||
action = 'customized'
|
||||
pass
|
||||
else:
|
||||
done = True
|
||||
else:
|
||||
done = True
|
||||
return datasets if ok_to_download else None
|
||||
|
||||
|
||||
|
||||
#-------------------------------Authenticate against Hugging Face
|
||||
def authenticate():
|
||||
print('''
|
||||
@ -180,13 +205,19 @@ This involves a few easy steps.
|
||||
You will need to verify your email address as part of the HuggingFace
|
||||
registration process.
|
||||
|
||||
2. Log into your account Hugging Face:
|
||||
2. Log into your Hugging Face account:
|
||||
|
||||
https://huggingface.co/login
|
||||
|
||||
3. Accept the license terms located here:
|
||||
|
||||
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
|
||||
https://huggingface.co/runwayml/stable-diffusion-v1-5
|
||||
|
||||
and here:
|
||||
|
||||
https://huggingface.co/runwayml/stable-diffusion-inpainting
|
||||
|
||||
(Yes, you have to accept two slightly different license agreements)
|
||||
'''
|
||||
)
|
||||
input('Press <enter> when you are ready to continue:')
|
||||
@ -229,7 +260,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
for mod in models.keys():
|
||||
repo_id = Datasets[mod]['repo_id']
|
||||
filename = Datasets[mod]['file']
|
||||
success = conditional_download(
|
||||
success = download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_name=filename,
|
||||
access_token=access_token
|
||||
@ -241,19 +272,50 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
return successful
|
||||
|
||||
#---------------------------------------------
|
||||
def conditional_download(repo_id:str, model_name:str, access_token:str):
|
||||
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
|
||||
model_dest = os.path.join(Model_dir, model_name)
|
||||
if os.path.exists(model_dest):
|
||||
print(f' * {model_name}: exists')
|
||||
return True
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
|
||||
header = {"Authorization": f'Bearer {access_token}'}
|
||||
open_mode = 'wb'
|
||||
exist_size = 0
|
||||
|
||||
if os.path.exists(model_dest):
|
||||
exist_size = os.path.getsize(model_dest)
|
||||
header['Range'] = f'bytes={exist_size}-'
|
||||
open_mode = 'ab'
|
||||
|
||||
resp = requests.get(url, headers=header, stream=True)
|
||||
total = int(resp.headers.get('content-length', 0))
|
||||
|
||||
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
||||
print(f'* {model_name}: complete file found. Skipping.')
|
||||
return True
|
||||
elif exist_size > 0:
|
||||
print(f'* {model_name}: partial file found. Resuming...')
|
||||
else:
|
||||
print(f'* {model_name}: Downloading...')
|
||||
|
||||
try:
|
||||
print(f' * {model_name}: downloading or retrieving from cache...')
|
||||
path = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token))
|
||||
path.resolve(strict=True).link_to(model_dest)
|
||||
if total < 2000:
|
||||
print(f'* {model_name}: {resp.text}')
|
||||
return False
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
desc=model_name,
|
||||
initial=exist_size,
|
||||
total=total+exist_size,
|
||||
unit='iB',
|
||||
unit_scale=True,
|
||||
unit_divisor=1000,
|
||||
) as bar:
|
||||
for data in resp.iter_content(chunk_size=1024):
|
||||
size = file.write(data)
|
||||
bar.update(size)
|
||||
except Exception as e:
|
||||
print(f'** Error downloading {model_name}: {str(e)} **')
|
||||
print(f'An error occurred while downloading {model_name}: {str(e)}')
|
||||
return False
|
||||
return True
|
||||
|
||||
@ -435,14 +497,15 @@ def download_safety_checker():
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
print('...success')
|
||||
|
||||
|
||||
#-------------------------------------
|
||||
if __name__ == '__main__':
|
||||
try:
|
||||
introduction()
|
||||
print('** WEIGHT SELECTION **')
|
||||
if user_wants_to_download_weights():
|
||||
models = select_datasets()
|
||||
choice = user_wants_to_download_weights()
|
||||
if choice != 'skip':
|
||||
models = select_datasets(choice)
|
||||
if models is None:
|
||||
if yes_or_no('Quit?',default_yes=False):
|
||||
sys.exit(0)
|
||||
|
Loading…
Reference in New Issue
Block a user