further improvements to preload_models.py

- Faster startup for command line switch processing
- Specify configuration file to modify using --config option:

  ./scripts/preload_models.ply --config models/my-models-file.yaml
This commit is contained in:
Lincoln Stein 2022-10-31 11:33:05 -04:00
parent b5cf734ba9
commit 7159ec885f

View File

@ -6,31 +6,30 @@
#
# Coauthor: Kevin Turner http://github.com/keturn
#
print('Loading Python libraries...\n')
import argparse
import clip
import sys
import transformers
import os
import warnings
import torch
import zipfile
import traceback
import getpass
import requests
from urllib import request
from tqdm import tqdm
from omegaconf import OmegaConf
from pathlib import Path
from transformers import CLIPTokenizer, CLIPTextModel
from transformers import BertTokenizerFast, AutoFeatureExtractor
from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url
import traceback
import getpass
import requests
transformers.logging.set_verbosity_error()
# deferred loading so that help message can be printed quickly
def load_libs():
print('Loading Python libraries...\n')
import clip
import transformers
import torch
import zipfile
transformers.logging.set_verbosity_error()
#--------------------------globals--
Model_dir = './models/ldm/stable-diffusion-v1/'
Config_file = './configs/models.yaml'
Default_config_file = './configs/models.yaml'
SD_Configs = './configs/stable-diffusion'
Datasets = {
'stable-diffusion-1.5': {
@ -221,6 +220,8 @@ This involves a few easy steps.
'''
)
input('Press <enter> when you are ready to continue:')
from huggingface_hub import HfFolder
access_token = HfFolder.get_token()
if access_token is None:
print('''
@ -273,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
#---------------------------------------------
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
from huggingface_hub import hf_hub_url
model_dest = os.path.join(Model_dir, model_name)
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
@ -320,8 +322,10 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
return True
#---------------------------------------------
def update_config_file(successfully_downloaded:dict):
yaml = new_config_file_contents(successfully_downloaded)
def update_config_file(successfully_downloaded:dict,opt:dict):
Config_file = opt.config_file or Default_config_file
yaml = new_config_file_contents(successfully_downloaded,Config_file)
try:
if os.path.exists(Config_file):
@ -341,7 +345,7 @@ def update_config_file(successfully_downloaded:dict):
#---------------------------------------------
def new_config_file_contents(successfully_downloaded:dict)->str:
def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str:
if os.path.exists(Config_file):
conf = OmegaConf.load(Config_file)
else:
@ -379,6 +383,7 @@ def new_config_file_contents(successfully_downloaded:dict)->str:
# this will preload the Bert tokenizer fles
def download_bert():
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
from transformers import BertTokenizerFast, AutoFeatureExtractor
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
@ -396,9 +401,10 @@ def download_kornia():
#---------------------------------------------
def download_clip():
version = 'openai/clip-vit-large-patch14'
sys.stdout.flush()
print('Loading CLIP model...',end='')
from transformers import CLIPTokenizer, CLIPTextModel
sys.stdout.flush()
version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version)
print('...success')
@ -498,6 +504,7 @@ def download_safety_checker():
print('Installing safety model for NSFW content detection...',end='')
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
except ModuleNotFoundError:
print('Error installing safety checker model:')
print(traceback.format_exc())
@ -515,7 +522,14 @@ if __name__ == '__main__':
action=argparse.BooleanOptionalAction,
default=True,
help='run in interactive mode (default)')
parser.add_argument('--config_file',
'-c',
dest='config_file',
type=str,
default='./configs/models.yaml',
help='path to configuration file to create')
opt = parser.parse_args()
load_libs()
try:
if opt.interactive:
@ -531,7 +545,7 @@ if __name__ == '__main__':
access_token = authenticate()
print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token)
update_config_file(successfully_downloaded)
update_config_file(successfully_downloaded,opt)
print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert()
download_kornia()