attempt to make batch install more reliable

1. added nvidia channel to environment.yml
2. updated pytorch-cuda requirement
3. let conda figure out what version of pytorch to install
4. add conda install status checking to .bat and .sh install files
5. in preload_models.py catch and handle download/access token errors
This commit is contained in:
Lincoln Stein 2022-11-01 12:02:22 -04:00
parent d7107d931a
commit ce298d32b5
4 changed files with 84 additions and 27 deletions

View File

@ -1,14 +1,14 @@
name: invokeai name: invokeai
channels: channels:
- pytorch - pytorch
- nvidia
- defaults - defaults
dependencies: dependencies:
- python>=3.9 - python>=3.9
- pip=20.3 - pip=22.2.2
- cudatoolkit=11.3 - numpy=1.23.3
- pytorch=1.11.0 - torchvision=0.14.0
- torchvision=0.12.0 - pytorch-cuda=11.7
- numpy=1.19.2
- pip: - pip:
- albumentations==0.4.3 - albumentations==0.4.3
- opencv-python==4.5.5.64 - opencv-python==4.5.5.64

View File

@ -80,14 +80,35 @@ if not exist ".git" (
call conda activate call conda activate
@rem create the environment @rem create the environment
call conda env remove -n invokeai
call conda env create call conda env create
call conda activate invokeai if "%ERRORLEVEL%" NE "0" (
echo ""
echo "Something went wrong while installing Python libraries and cannot continue.
echo "Please visit https://invoke-ai.github.io/InvokeAI/#installation for alternative"
echo "installation methods."
echo "Press any key to continue"
pause
exit /b
)
call conda activate invokeai
@rem preload the models @rem preload the models
call python scripts\preload_models.py call python scripts\preload_models.py
if "%ERRORLEVEL%" NE "0" (
echo ""
echo "The preload_models.py script crashed or was cancelled."
echo "InvokeAI is not ready to run. To run preload_models.py again,"
echo "run the command 'python scripts/preload_models.py'"
echo "Press any key to continue"
pause
exit /b
)
@rem tell the user their next steps @rem tell the user their next steps
echo. echo ""
echo "You can now start generating images by double-clicking the 'invoke.bat' file (inside this folder) echo "You can now start generating images by double-clicking the 'invoke.bat' file (inside this folder)
echo "Press any key to continue"
pause pause
exit 0

View File

@ -99,13 +99,29 @@ conda activate
if [ "$OS_NAME" == "mac" ]; then if [ "$OS_NAME" == "mac" ]; then
PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-${OS_ARCH} conda env create -f environment-mac.yml PIP_EXISTS_ACTION=w CONDA_SUBDIR=osx-${OS_ARCH} conda env create -f environment-mac.yml
else else
conda env remove -n invokeai
conda env create -f environment.yml conda env create -f environment.yml
fi fi
conda activate invokeai status = $?
# preload the models if test $status -ne 0
python scripts/preload_models.py then
echo "Something went wrong while installing Python libraries and cannot continue.
echo "Please visit https://invoke-ai.github.io/InvokeAI/#installation for alternative"
echo "installation methods"
else
conda activate invokeai
# tell the user their next steps # preload the models
echo "You can now start generating images by running invoke.sh (inside this folder), using ./invoke.sh" echo "Calling the preload_models.py script"
python scripts/preload_models.py
if test $? -ne 0
then
echo "The preload_models.py script crashed or was cancelled."
echo "InvokeAI is not ready to run. To run preload_models.py again,"
echo "give the command 'python scripts/preload_models.py'"
else
# tell the user their next steps
echo "You can now start generating images by running invoke.sh (inside this folder), using ./invoke.sh"
fi

View File

@ -18,6 +18,12 @@ import traceback
import getpass import getpass
import requests import requests
import warnings
warnings.filterwarnings('ignore')
#warnings.simplefilter('ignore')
#warnings.filterwarnings('ignore',category=DeprecationWarning)
#warnings.filterwarnings('ignore',category=UserWarning)
# deferred loading so that help message can be printed quickly # deferred loading so that help message can be printed quickly
def load_libs(): def load_libs():
print('Loading Python libraries...\n') print('Loading Python libraries...\n')
@ -238,7 +244,6 @@ This involves a few easy steps.
Now copy the token to your clipboard and paste it here: ''' Now copy the token to your clipboard and paste it here: '''
) )
access_token = getpass.getpass() access_token = getpass.getpass()
HfFolder.save_token(access_token)
return access_token return access_token
#--------------------------------------------- #---------------------------------------------
@ -256,6 +261,7 @@ def migrate_models_ckpt():
#--------------------------------------------- #---------------------------------------------
def download_weight_datasets(models:dict, access_token:str): def download_weight_datasets(models:dict, access_token:str):
from huggingface_hub import HfFolder
migrate_models_ckpt() migrate_models_ckpt()
successful = dict() successful = dict()
for mod in models.keys(): for mod in models.keys():
@ -268,6 +274,12 @@ def download_weight_datasets(models:dict, access_token:str):
) )
if success: if success:
successful[mod] = True successful[mod] = True
if len(successful) < len(models):
print(f'\n* There were errors downloading one or more files.')
print('Please double-check your license agreements, and your access token. Type ^C to quit.\n')
return None
HfFolder.save_token(access_token)
keys = ', '.join(successful.keys()) keys = ', '.join(successful.keys())
print(f'Successfully installed {keys}') print(f'Successfully installed {keys}')
return successful return successful
@ -295,6 +307,8 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
if resp.status_code==416: # "range not satisfiable", which means nothing to return if resp.status_code==416: # "range not satisfiable", which means nothing to return
print(f'* {model_name}: complete file found. Skipping.') print(f'* {model_name}: complete file found. Skipping.')
return True return True
elif resp.status_code != 200:
print(f'** An error occurred during downloading {model_name}: {resp.reason}')
elif exist_size > 0: elif exist_size > 0:
print(f'* {model_name}: partial file found. Resuming...') print(f'* {model_name}: partial file found. Resuming...')
else: else:
@ -302,7 +316,7 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
try: try:
if total < 2000: if total < 2000:
print(f'* {model_name}: {resp.text}') print(f'*** ERROR DOWNLOADING {model_name}: {resp.text}')
return False return False
with open(model_dest, open_mode) as file, tqdm( with open(model_dest, open_mode) as file, tqdm(
@ -329,7 +343,7 @@ def update_config_file(successfully_downloaded:dict,opt:dict):
try: try:
if os.path.exists(Config_file): if os.path.exists(Config_file):
print(f'** {Config_file} exists. Renaming to {Config_file}.orig') print(f'* {Config_file} exists. Renaming to {Config_file}.orig')
os.rename(Config_file,f'{Config_file}.orig') os.rename(Config_file,f'{Config_file}.orig')
tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp') tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
with open(tmpfile, 'w') as outfile: with open(tmpfile, 'w') as outfile:
@ -383,26 +397,27 @@ def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str
# this will preload the Bert tokenizer fles # this will preload the Bert tokenizer fles
def download_bert(): def download_bert():
print('Installing bert tokenizer (ignore deprecation errors)...', end='') print('Installing bert tokenizer (ignore deprecation errors)...', end='')
from transformers import BertTokenizerFast, AutoFeatureExtractor sys.stdout.flush()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
from transformers import BertTokenizerFast, AutoFeatureExtractor
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
print('...success') print('...success')
sys.stdout.flush()
#--------------------------------------------- #---------------------------------------------
# this will download requirements for Kornia # this will download requirements for Kornia
def download_kornia(): def download_kornia():
print('Installing Kornia requirements...', end='') print('Installing Kornia requirements (ignore deprecation errors)...', end='')
with warnings.catch_warnings(): sys.stdout.flush()
warnings.filterwarnings('ignore', category=DeprecationWarning) import kornia
import kornia
print('...success') print('...success')
#--------------------------------------------- #---------------------------------------------
def download_clip(): def download_clip():
print('Loading CLIP model...',end='') print('Loading CLIP model...',end='')
from transformers import CLIPTokenizer, CLIPTextModel with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
from transformers import CLIPTokenizer, CLIPTextModel
sys.stdout.flush() sys.stdout.flush()
version = 'openai/clip-vit-large-patch14' version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version) tokenizer = CLIPTokenizer.from_pretrained(version)
@ -541,11 +556,16 @@ if __name__ == '__main__':
if models is None: if models is None:
if yes_or_no('Quit?',default_yes=False): if yes_or_no('Quit?',default_yes=False):
sys.exit(0) sys.exit(0)
print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
access_token = authenticate() done = False
print('\n** DOWNLOADING WEIGHTS **') while not done:
successfully_downloaded = download_weight_datasets(models, access_token) print('** LICENSE AGREEMENT FOR WEIGHT FILES **')
access_token = authenticate()
print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token)
done = successfully_downloaded is not None
update_config_file(successfully_downloaded,opt) update_config_file(successfully_downloaded,opt)
print('\n** DOWNLOADING SUPPORT MODELS **') print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert() download_bert()
download_kornia() download_kornia()