From 7159ec885faa1117a3506689d547618be1d91808 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Mon, 31 Oct 2022 11:33:05 -0400 Subject: [PATCH] further improvements to preload_models.py - Faster startup for command line switch processing - Specify configuration file to modify using --config option: ./scripts/preload_models.py --config models/my-models-file.yaml --- scripts/preload_models.py | 52 +++++++++++++++++++++++++-------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/scripts/preload_models.py b/scripts/preload_models.py index 3fac4c6a0a..e5da9eca69 100644 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -6,31 +6,30 @@ # # Coauthor: Kevin Turner http://github.com/keturn # -print('Loading Python libraries...\n') import argparse -import clip import sys -import transformers import os import warnings -import torch -import zipfile -import traceback -import getpass -import requests from urllib import request from tqdm import tqdm from omegaconf import OmegaConf from pathlib import Path -from transformers import CLIPTokenizer, CLIPTextModel -from transformers import BertTokenizerFast, AutoFeatureExtractor -from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url +import traceback +import getpass +import requests -transformers.logging.set_verbosity_error() +# deferred loading so that help message can be printed quickly +def load_libs(): +    print('Loading Python libraries...\n') +    import clip +    import transformers +    import torch +    import zipfile +    transformers.logging.set_verbosity_error() #--------------------------globals-- Model_dir = './models/ldm/stable-diffusion-v1/' -Config_file = './configs/models.yaml' +Default_config_file = './configs/models.yaml' SD_Configs = './configs/stable-diffusion' Datasets = { 'stable-diffusion-1.5': { @@ -221,6 +220,8 @@ This involves a few easy steps.
''' ) input('Press when you are ready to continue:') + + from huggingface_hub import HfFolder access_token = HfFolder.get_token() if access_token is None: print(''' @@ -273,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str): #--------------------------------------------- def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: + from huggingface_hub import hf_hub_url model_dest = os.path.join(Model_dir, model_name) os.makedirs(os.path.dirname(model_dest), exist_ok=True) @@ -320,8 +322,10 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: return True #--------------------------------------------- -def update_config_file(successfully_downloaded:dict): - yaml = new_config_file_contents(successfully_downloaded) +def update_config_file(successfully_downloaded:dict,opt:dict): + Config_file = opt.config_file or Default_config_file + + yaml = new_config_file_contents(successfully_downloaded,Config_file) try: if os.path.exists(Config_file): @@ -341,7 +345,7 @@ def update_config_file(successfully_downloaded:dict): #--------------------------------------------- -def new_config_file_contents(successfully_downloaded:dict)->str: +def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str: if os.path.exists(Config_file): conf = OmegaConf.load(Config_file) else: @@ -379,6 +383,7 @@ def new_config_file_contents(successfully_downloaded:dict)->str: # this will preload the Bert tokenizer fles def download_bert(): print('Installing bert tokenizer (ignore deprecation errors)...', end='') + from transformers import BertTokenizerFast, AutoFeatureExtractor with warnings.catch_warnings(): warnings.filterwarnings('ignore', category=DeprecationWarning) tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') @@ -396,9 +401,10 @@ def download_kornia(): #--------------------------------------------- def download_clip(): - version = 'openai/clip-vit-large-patch14' - sys.stdout.flush() 
print('Loading CLIP model...',end='') + from transformers import CLIPTokenizer, CLIPTextModel + sys.stdout.flush() + version = 'openai/clip-vit-large-patch14' tokenizer = CLIPTokenizer.from_pretrained(version) transformer = CLIPTextModel.from_pretrained(version) print('...success') @@ -498,6 +504,7 @@ def download_safety_checker(): print('Installing safety model for NSFW content detection...',end='') try: from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + from transformers import AutoFeatureExtractor except ModuleNotFoundError: print('Error installing safety checker model:') print(traceback.format_exc()) @@ -515,7 +522,14 @@ if __name__ == '__main__': action=argparse.BooleanOptionalAction, default=True, help='run in interactive mode (default)') + parser.add_argument('--config_file', + '-c', + dest='config_file', + type=str, + default='./configs/models.yaml', + help='path to configuration file to create') opt = parser.parse_args() + load_libs() try: if opt.interactive: @@ -531,7 +545,7 @@ if __name__ == '__main__': access_token = authenticate() print('\n** DOWNLOADING WEIGHTS **') successfully_downloaded = download_weight_datasets(models, access_token) - update_config_file(successfully_downloaded) + update_config_file(successfully_downloaded,opt) print('\n** DOWNLOADING SUPPORT MODELS **') download_bert() download_kornia()