mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further improvements to preload_models script
- User can choose to download just recommended models, customize list to download, or skip downloading altogether. - Does direct download to models directory instead of to HuggingFace cache - Able to resume interrupted downloads
This commit is contained in:
parent
b532e6dd17
commit
19b6c671a6
@ -13,10 +13,13 @@ import transformers
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
import torch
|
import torch
|
||||||
import urllib.request
|
|
||||||
import zipfile
|
import zipfile
|
||||||
import traceback
|
import traceback
|
||||||
import getpass
|
import getpass
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from urllib import request
|
||||||
|
from tqdm import tqdm
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
@ -119,35 +122,55 @@ def yes_or_no(prompt:str, default_yes=True):
|
|||||||
return response[0] in ('y','Y')
|
return response[0] in ('y','Y')
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def user_wants_to_download_weights():
|
def user_wants_to_download_weights()->str:
|
||||||
|
'''
|
||||||
|
Returns one of "skip", "recommended" or "customized"
|
||||||
|
'''
|
||||||
print('''You can download and configure the weights files manually or let this
|
print('''You can download and configure the weights files manually or let this
|
||||||
script do it for you. Manual installation is described at:
|
script do it for you. Manual installation is described at:
|
||||||
|
|
||||||
https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md
|
https://github.com/invoke-ai/InvokeAI/blob/main/docs/installation/INSTALLING_MODELS.md
|
||||||
|
|
||||||
|
You may download the recommended models (about 10GB total), select a customized set, or
|
||||||
|
completely skip this step.
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
return yes_or_no('Would you like to download the Stable Diffusion model weights now?')
|
selection = None
|
||||||
|
while selection is None:
|
||||||
|
choice = input('Download <r>ecommended models, <c>ustomize the list, or <s>kip this step? [r]: ')
|
||||||
|
if choice.startswith(('r','R')) or len(choice)==0:
|
||||||
|
selection = 'recommended'
|
||||||
|
elif choice.startswith(('c','C')):
|
||||||
|
selection = 'customized'
|
||||||
|
elif choice.startswith(('s','S')):
|
||||||
|
selection = 'skip'
|
||||||
|
return selection
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def select_datasets():
|
def select_datasets(action:str):
|
||||||
done = False
|
done = False
|
||||||
while not done:
|
while not done:
|
||||||
|
datasets = dict()
|
||||||
|
dflt = None # the first model selected will be the default; TODO let user change
|
||||||
|
counter = 1
|
||||||
|
|
||||||
|
if action == 'customized':
|
||||||
print('''
|
print('''
|
||||||
Choose the weight file(s) you wish to download. Before downloading you
|
Choose the weight file(s) you wish to download. Before downloading you
|
||||||
will be given the option to view and change your selections.
|
will be given the option to view and change your selections.
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
datasets = dict()
|
|
||||||
|
|
||||||
counter = 1
|
|
||||||
dflt = None # the first model selected will be the default; TODO let user change
|
|
||||||
for ds in Datasets.keys():
|
for ds in Datasets.keys():
|
||||||
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
recommended = '(recommended)' if Datasets[ds]['recommended'] else ''
|
||||||
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
print(f'[{counter}] {ds}:\n {Datasets[ds]["description"]} {recommended}')
|
||||||
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
if yes_or_no(' Download?',default_yes=Datasets[ds]['recommended']):
|
||||||
datasets[ds]=counter
|
datasets[ds]=counter
|
||||||
counter += 1
|
counter += 1
|
||||||
|
else:
|
||||||
|
for ds in Datasets.keys():
|
||||||
|
if Datasets[ds]['recommended']:
|
||||||
|
datasets[ds]=counter
|
||||||
|
counter += 1
|
||||||
|
|
||||||
print('The following weight files will be downloaded:')
|
print('The following weight files will be downloaded:')
|
||||||
for ds in datasets:
|
for ds in datasets:
|
||||||
@ -157,6 +180,7 @@ will be given the option to view and change your selections.
|
|||||||
ok_to_download = yes_or_no('Ok to download?')
|
ok_to_download = yes_or_no('Ok to download?')
|
||||||
if not ok_to_download:
|
if not ok_to_download:
|
||||||
if yes_or_no('Change your selection?'):
|
if yes_or_no('Change your selection?'):
|
||||||
|
action = 'customized'
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
done = True
|
done = True
|
||||||
@ -164,6 +188,7 @@ will be given the option to view and change your selections.
|
|||||||
done = True
|
done = True
|
||||||
return datasets if ok_to_download else None
|
return datasets if ok_to_download else None
|
||||||
|
|
||||||
|
|
||||||
#-------------------------------Authenticate against Hugging Face
|
#-------------------------------Authenticate against Hugging Face
|
||||||
def authenticate():
|
def authenticate():
|
||||||
print('''
|
print('''
|
||||||
@ -180,13 +205,19 @@ This involves a few easy steps.
|
|||||||
You will need to verify your email address as part of the HuggingFace
|
You will need to verify your email address as part of the HuggingFace
|
||||||
registration process.
|
registration process.
|
||||||
|
|
||||||
2. Log into your account Hugging Face:
|
2. Log into your Hugging Face account:
|
||||||
|
|
||||||
https://huggingface.co/login
|
https://huggingface.co/login
|
||||||
|
|
||||||
3. Accept the license terms located here:
|
3. Accept the license terms located here:
|
||||||
|
|
||||||
https://huggingface.co/CompVis/stable-diffusion-v-1-4-original
|
https://huggingface.co/runwayml/stable-diffusion-v1-5
|
||||||
|
|
||||||
|
and here:
|
||||||
|
|
||||||
|
https://huggingface.co/runwayml/stable-diffusion-inpainting
|
||||||
|
|
||||||
|
(Yes, you have to accept two slightly different license agreements)
|
||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
input('Press <enter> when you are ready to continue:')
|
input('Press <enter> when you are ready to continue:')
|
||||||
@ -229,7 +260,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
|||||||
for mod in models.keys():
|
for mod in models.keys():
|
||||||
repo_id = Datasets[mod]['repo_id']
|
repo_id = Datasets[mod]['repo_id']
|
||||||
filename = Datasets[mod]['file']
|
filename = Datasets[mod]['file']
|
||||||
success = conditional_download(
|
success = download_with_resume(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
model_name=filename,
|
model_name=filename,
|
||||||
access_token=access_token
|
access_token=access_token
|
||||||
@ -241,19 +272,50 @@ def download_weight_datasets(models:dict, access_token:str):
|
|||||||
return successful
|
return successful
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def conditional_download(repo_id:str, model_name:str, access_token:str):
|
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||||
|
|
||||||
model_dest = os.path.join(Model_dir, model_name)
|
model_dest = os.path.join(Model_dir, model_name)
|
||||||
if os.path.exists(model_dest):
|
|
||||||
print(f' * {model_name}: exists')
|
|
||||||
return True
|
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
|
url = hf_hub_url(repo_id, model_name)
|
||||||
|
|
||||||
|
header = {"Authorization": f'Bearer {access_token}'}
|
||||||
|
open_mode = 'wb'
|
||||||
|
exist_size = 0
|
||||||
|
|
||||||
|
if os.path.exists(model_dest):
|
||||||
|
exist_size = os.path.getsize(model_dest)
|
||||||
|
header['Range'] = f'bytes={exist_size}-'
|
||||||
|
open_mode = 'ab'
|
||||||
|
|
||||||
|
resp = requests.get(url, headers=header, stream=True)
|
||||||
|
total = int(resp.headers.get('content-length', 0))
|
||||||
|
|
||||||
|
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
||||||
|
print(f'* {model_name}: complete file found. Skipping.')
|
||||||
|
return True
|
||||||
|
elif exist_size > 0:
|
||||||
|
print(f'* {model_name}: partial file found. Resuming...')
|
||||||
|
else:
|
||||||
|
print(f'* {model_name}: Downloading...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print(f' * {model_name}: downloading or retrieving from cache...')
|
if total < 2000:
|
||||||
path = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token))
|
print(f'* {model_name}: {resp.text}')
|
||||||
path.resolve(strict=True).link_to(model_dest)
|
return False
|
||||||
|
|
||||||
|
with open(model_dest, open_mode) as file, tqdm(
|
||||||
|
desc=model_name,
|
||||||
|
initial=exist_size,
|
||||||
|
total=total+exist_size,
|
||||||
|
unit='iB',
|
||||||
|
unit_scale=True,
|
||||||
|
unit_divisor=1000,
|
||||||
|
) as bar:
|
||||||
|
for data in resp.iter_content(chunk_size=1024):
|
||||||
|
size = file.write(data)
|
||||||
|
bar.update(size)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f'** Error downloading {model_name}: {str(e)} **')
|
print(f'An error occurred while downloading {model_name}: {str(e)}')
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -441,8 +503,9 @@ if __name__ == '__main__':
|
|||||||
try:
|
try:
|
||||||
introduction()
|
introduction()
|
||||||
print('** WEIGHT SELECTION **')
|
print('** WEIGHT SELECTION **')
|
||||||
if user_wants_to_download_weights():
|
choice = user_wants_to_download_weights()
|
||||||
models = select_datasets()
|
if choice != 'skip':
|
||||||
|
models = select_datasets(choice)
|
||||||
if models is None:
|
if models is None:
|
||||||
if yes_or_no('Quit?',default_yes=False):
|
if yes_or_no('Quit?',default_yes=False):
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
|
Loading…
Reference in New Issue
Block a user