mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
attempt to make batch install more reliable
1. added nvidia channel to environment.yml 2. updated pytorch-cuda requirement 3. let conda figure out what version of pytorch to install 4. add conda install status checking to .bat and .sh install files 5. in preload_models.py catch and handle download/access token errors
This commit is contained in:
@ -18,6 +18,12 @@ import traceback
|
||||
import getpass
|
||||
import requests
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore')
|
||||
#warnings.simplefilter('ignore')
|
||||
#warnings.filterwarnings('ignore',category=DeprecationWarning)
|
||||
#warnings.filterwarnings('ignore',category=UserWarning)
|
||||
|
||||
# deferred loading so that help message can be printed quickly
|
||||
def load_libs():
|
||||
print('Loading Python libraries...\n')
|
||||
@ -238,7 +244,6 @@ This involves a few easy steps.
|
||||
Now copy the token to your clipboard and paste it here: '''
|
||||
)
|
||||
access_token = getpass.getpass()
|
||||
HfFolder.save_token(access_token)
|
||||
return access_token
|
||||
|
||||
#---------------------------------------------
|
||||
@ -256,6 +261,7 @@ def migrate_models_ckpt():
|
||||
|
||||
#---------------------------------------------
|
||||
def download_weight_datasets(models:dict, access_token:str):
|
||||
from huggingface_hub import HfFolder
|
||||
migrate_models_ckpt()
|
||||
successful = dict()
|
||||
for mod in models.keys():
|
||||
@ -268,6 +274,12 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
)
|
||||
if success:
|
||||
successful[mod] = True
|
||||
if len(successful) < len(models):
|
||||
print(f'\n* There were errors downloading one or more files.')
|
||||
print('Please double-check your license agreements, and your access token. Type ^C to quit.\n')
|
||||
return None
|
||||
|
||||
HfFolder.save_token(access_token)
|
||||
keys = ', '.join(successful.keys())
|
||||
print(f'Successfully installed {keys}')
|
||||
return successful
|
||||
@ -295,6 +307,8 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
if resp.status_code==416: # "range not satisfiable", which means nothing to return
|
||||
print(f'* {model_name}: complete file found. Skipping.')
|
||||
return True
|
||||
elif resp.status_code != 200:
|
||||
print(f'** An error occurred during downloading {model_name}: {resp.reason}')
|
||||
elif exist_size > 0:
|
||||
print(f'* {model_name}: partial file found. Resuming...')
|
||||
else:
|
||||
@ -302,7 +316,7 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
|
||||
try:
|
||||
if total < 2000:
|
||||
print(f'* {model_name}: {resp.text}')
|
||||
print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
|
||||
return False
|
||||
|
||||
with open(model_dest, open_mode) as file, tqdm(
|
||||
@ -329,7 +343,7 @@ def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||
|
||||
try:
|
||||
if os.path.exists(Config_file):
|
||||
print(f'** {Config_file} exists. Renaming to {Config_file}.orig')
|
||||
print(f'* {Config_file} exists. Renaming to {Config_file}.orig')
|
||||
os.rename(Config_file,f'{Config_file}.orig')
|
||||
tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
|
||||
with open(tmpfile, 'w') as outfile:
|
||||
@ -383,26 +397,27 @@ def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
sys.stdout.flush()
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
sys.stdout.flush()
|
||||
|
||||
#---------------------------------------------
|
||||
# this will download requirements for Kornia
|
||||
def download_kornia():
|
||||
print('Installing Kornia requirements...', end='')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
import kornia
|
||||
print('Installing Kornia requirements (ignore deprecation errors)...', end='')
|
||||
sys.stdout.flush()
|
||||
import kornia
|
||||
print('...success')
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
print('Loading CLIP model...',end='')
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
sys.stdout.flush()
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
@ -541,11 +556,16 @@ if __name__ == '__main__':
|
||||
if models is None:
|
||||
if yes_or_no('Quit?',default_yes=False):
|
||||
sys.exit(0)
|
||||
print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
|
||||
access_token = authenticate()
|
||||
print('\n** DOWNLOADING WEIGHTS **')
|
||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||
|
||||
done = False
|
||||
while not done:
|
||||
print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
|
||||
access_token = authenticate()
|
||||
print('\n** DOWNLOADING WEIGHTS **')
|
||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||
done = successfully_downloaded is not None
|
||||
update_config_file(successfully_downloaded,opt)
|
||||
|
||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||
download_bert()
|
||||
download_kornia()
|
||||
|
Reference in New Issue
Block a user