mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further improvements to preload_models.py
- Faster startup for command line switch processing - Specify configuration file to modify using --config option: ./scripts/preload_models.ply --config models/my-models-file.yaml
This commit is contained in:
parent
b5cf734ba9
commit
7159ec885f
@ -6,31 +6,30 @@
|
||||
#
|
||||
# Coauthor: Kevin Turner http://github.com/keturn
|
||||
#
|
||||
print('Loading Python libraries...\n')
|
||||
import argparse
|
||||
import clip
|
||||
import sys
|
||||
import transformers
|
||||
import os
|
||||
import warnings
|
||||
import torch
|
||||
import zipfile
|
||||
import traceback
|
||||
import getpass
|
||||
import requests
|
||||
from urllib import request
|
||||
from tqdm import tqdm
|
||||
from omegaconf import OmegaConf
|
||||
from pathlib import Path
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url
|
||||
import traceback
|
||||
import getpass
|
||||
import requests
|
||||
|
||||
transformers.logging.set_verbosity_error()
|
||||
# deferred loading so that help message can be printed quickly
|
||||
def load_libs():
|
||||
print('Loading Python libraries...\n')
|
||||
import clip
|
||||
import transformers
|
||||
import torch
|
||||
import zipfile
|
||||
transformers.logging.set_verbosity_error()
|
||||
|
||||
#--------------------------globals--
|
||||
Model_dir = './models/ldm/stable-diffusion-v1/'
|
||||
Config_file = './configs/models.yaml'
|
||||
Default_config_file = './configs/models.yaml'
|
||||
SD_Configs = './configs/stable-diffusion'
|
||||
Datasets = {
|
||||
'stable-diffusion-1.5': {
|
||||
@ -221,6 +220,8 @@ This involves a few easy steps.
|
||||
'''
|
||||
)
|
||||
input('Press <enter> when you are ready to continue:')
|
||||
|
||||
from huggingface_hub import HfFolder
|
||||
access_token = HfFolder.get_token()
|
||||
if access_token is None:
|
||||
print('''
|
||||
@ -273,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
|
||||
#---------------------------------------------
|
||||
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
from huggingface_hub import hf_hub_url
|
||||
|
||||
model_dest = os.path.join(Model_dir, model_name)
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
@ -320,8 +322,10 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
return True
|
||||
|
||||
#---------------------------------------------
|
||||
def update_config_file(successfully_downloaded:dict):
|
||||
yaml = new_config_file_contents(successfully_downloaded)
|
||||
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||
Config_file = opt.config_file or Default_config_file
|
||||
|
||||
yaml = new_config_file_contents(successfully_downloaded,Config_file)
|
||||
|
||||
try:
|
||||
if os.path.exists(Config_file):
|
||||
@ -341,7 +345,7 @@ def update_config_file(successfully_downloaded:dict):
|
||||
|
||||
|
||||
#---------------------------------------------
|
||||
def new_config_file_contents(successfully_downloaded:dict)->str:
|
||||
def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str:
|
||||
if os.path.exists(Config_file):
|
||||
conf = OmegaConf.load(Config_file)
|
||||
else:
|
||||
@ -379,6 +383,7 @@ def new_config_file_contents(successfully_downloaded:dict)->str:
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
@ -396,9 +401,10 @@ def download_kornia():
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model...',end='')
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
sys.stdout.flush()
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
@ -498,6 +504,7 @@ def download_safety_checker():
|
||||
print('Installing safety model for NSFW content detection...',end='')
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
except ModuleNotFoundError:
|
||||
print('Error installing safety checker model:')
|
||||
print(traceback.format_exc())
|
||||
@ -515,7 +522,14 @@ if __name__ == '__main__':
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='run in interactive mode (default)')
|
||||
parser.add_argument('--config_file',
|
||||
'-c',
|
||||
dest='config_file',
|
||||
type=str,
|
||||
default='./configs/models.yaml',
|
||||
help='path to configuration file to create')
|
||||
opt = parser.parse_args()
|
||||
load_libs()
|
||||
|
||||
try:
|
||||
if opt.interactive:
|
||||
@ -531,7 +545,7 @@ if __name__ == '__main__':
|
||||
access_token = authenticate()
|
||||
print('\n** DOWNLOADING WEIGHTS **')
|
||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||
update_config_file(successfully_downloaded)
|
||||
update_config_file(successfully_downloaded,opt)
|
||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||
download_bert()
|
||||
download_kornia()
|
||||
|
Loading…
Reference in New Issue
Block a user