further improvements to preload_models.py

- Faster startup for command line switch processing
- Specify configuration file to modify using --config option:

  ./scripts/preload_models.ply --config models/my-models-file.yaml
This commit is contained in:
Lincoln Stein 2022-10-31 11:33:05 -04:00
parent b5cf734ba9
commit 7159ec885f

View File

@ -6,31 +6,30 @@
# #
# Coauthor: Kevin Turner http://github.com/keturn # Coauthor: Kevin Turner http://github.com/keturn
# #
print('Loading Python libraries...\n')
import argparse import argparse
import clip
import sys import sys
import transformers
import os import os
import warnings import warnings
import torch
import zipfile
import traceback
import getpass
import requests
from urllib import request from urllib import request
from tqdm import tqdm from tqdm import tqdm
from omegaconf import OmegaConf from omegaconf import OmegaConf
from pathlib import Path from pathlib import Path
from transformers import CLIPTokenizer, CLIPTextModel import traceback
from transformers import BertTokenizerFast, AutoFeatureExtractor import getpass
from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url import requests
transformers.logging.set_verbosity_error() # deferred loading so that help message can be printed quickly
def load_libs():
print('Loading Python libraries...\n')
import clip
import transformers
import torch
import zipfile
transformers.logging.set_verbosity_error()
#--------------------------globals-- #--------------------------globals--
Model_dir = './models/ldm/stable-diffusion-v1/' Model_dir = './models/ldm/stable-diffusion-v1/'
Config_file = './configs/models.yaml' Default_config_file = './configs/models.yaml'
SD_Configs = './configs/stable-diffusion' SD_Configs = './configs/stable-diffusion'
Datasets = { Datasets = {
'stable-diffusion-1.5': { 'stable-diffusion-1.5': {
@ -221,6 +220,8 @@ This involves a few easy steps.
''' '''
) )
input('Press <enter> when you are ready to continue:') input('Press <enter> when you are ready to continue:')
from huggingface_hub import HfFolder
access_token = HfFolder.get_token() access_token = HfFolder.get_token()
if access_token is None: if access_token is None:
print(''' print('''
@ -273,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
#--------------------------------------------- #---------------------------------------------
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
from huggingface_hub import hf_hub_url
model_dest = os.path.join(Model_dir, model_name) model_dest = os.path.join(Model_dir, model_name)
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
@ -320,8 +322,10 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
return True return True
#--------------------------------------------- #---------------------------------------------
def update_config_file(successfully_downloaded:dict): def update_config_file(successfully_downloaded:dict,opt:dict):
yaml = new_config_file_contents(successfully_downloaded) Config_file = opt.config_file or Default_config_file
yaml = new_config_file_contents(successfully_downloaded,Config_file)
try: try:
if os.path.exists(Config_file): if os.path.exists(Config_file):
@ -341,7 +345,7 @@ def update_config_file(successfully_downloaded:dict):
#--------------------------------------------- #---------------------------------------------
def new_config_file_contents(successfully_downloaded:dict)->str: def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str:
if os.path.exists(Config_file): if os.path.exists(Config_file):
conf = OmegaConf.load(Config_file) conf = OmegaConf.load(Config_file)
else: else:
@ -379,6 +383,7 @@ def new_config_file_contents(successfully_downloaded:dict)->str:
# this will preload the Bert tokenizer fles # this will preload the Bert tokenizer fles
def download_bert(): def download_bert():
print('Installing bert tokenizer (ignore deprecation errors)...', end='') print('Installing bert tokenizer (ignore deprecation errors)...', end='')
from transformers import BertTokenizerFast, AutoFeatureExtractor
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
@ -396,9 +401,10 @@ def download_kornia():
#--------------------------------------------- #---------------------------------------------
def download_clip(): def download_clip():
version = 'openai/clip-vit-large-patch14'
sys.stdout.flush()
print('Loading CLIP model...',end='') print('Loading CLIP model...',end='')
from transformers import CLIPTokenizer, CLIPTextModel
sys.stdout.flush()
version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version) tokenizer = CLIPTokenizer.from_pretrained(version)
transformer = CLIPTextModel.from_pretrained(version) transformer = CLIPTextModel.from_pretrained(version)
print('...success') print('...success')
@ -498,6 +504,7 @@ def download_safety_checker():
print('Installing safety model for NSFW content detection...',end='') print('Installing safety model for NSFW content detection...',end='')
try: try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
except ModuleNotFoundError: except ModuleNotFoundError:
print('Error installing safety checker model:') print('Error installing safety checker model:')
print(traceback.format_exc()) print(traceback.format_exc())
@ -515,7 +522,14 @@ if __name__ == '__main__':
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
default=True, default=True,
help='run in interactive mode (default)') help='run in interactive mode (default)')
parser.add_argument('--config_file',
'-c',
dest='config_file',
type=str,
default='./configs/models.yaml',
help='path to configuration file to create')
opt = parser.parse_args() opt = parser.parse_args()
load_libs()
try: try:
if opt.interactive: if opt.interactive:
@ -531,7 +545,7 @@ if __name__ == '__main__':
access_token = authenticate() access_token = authenticate()
print('\n** DOWNLOADING WEIGHTS **') print('\n** DOWNLOADING WEIGHTS **')
successfully_downloaded = download_weight_datasets(models, access_token) successfully_downloaded = download_weight_datasets(models, access_token)
update_config_file(successfully_downloaded) update_config_file(successfully_downloaded,opt)
print('\n** DOWNLOADING SUPPORT MODELS **') print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert() download_bert()
download_kornia() download_kornia()