mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
further improvements to preload_models.py
- Faster startup for command line switch processing - Specify configuration file to modify using --config option: ./scripts/preload_models.ply --config models/my-models-file.yaml
This commit is contained in:
parent
b5cf734ba9
commit
7159ec885f
@ -6,31 +6,30 @@
|
|||||||
#
|
#
|
||||||
# Coauthor: Kevin Turner http://github.com/keturn
|
# Coauthor: Kevin Turner http://github.com/keturn
|
||||||
#
|
#
|
||||||
print('Loading Python libraries...\n')
|
|
||||||
import argparse
|
import argparse
|
||||||
import clip
|
|
||||||
import sys
|
import sys
|
||||||
import transformers
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
import torch
|
|
||||||
import zipfile
|
|
||||||
import traceback
|
|
||||||
import getpass
|
|
||||||
import requests
|
|
||||||
from urllib import request
|
from urllib import request
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from omegaconf import OmegaConf
|
from omegaconf import OmegaConf
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
import traceback
|
||||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
import getpass
|
||||||
from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url
|
import requests
|
||||||
|
|
||||||
transformers.logging.set_verbosity_error()
|
# deferred loading so that help message can be printed quickly
|
||||||
|
def load_libs():
|
||||||
|
print('Loading Python libraries...\n')
|
||||||
|
import clip
|
||||||
|
import transformers
|
||||||
|
import torch
|
||||||
|
import zipfile
|
||||||
|
transformers.logging.set_verbosity_error()
|
||||||
|
|
||||||
#--------------------------globals--
|
#--------------------------globals--
|
||||||
Model_dir = './models/ldm/stable-diffusion-v1/'
|
Model_dir = './models/ldm/stable-diffusion-v1/'
|
||||||
Config_file = './configs/models.yaml'
|
Default_config_file = './configs/models.yaml'
|
||||||
SD_Configs = './configs/stable-diffusion'
|
SD_Configs = './configs/stable-diffusion'
|
||||||
Datasets = {
|
Datasets = {
|
||||||
'stable-diffusion-1.5': {
|
'stable-diffusion-1.5': {
|
||||||
@ -221,6 +220,8 @@ This involves a few easy steps.
|
|||||||
'''
|
'''
|
||||||
)
|
)
|
||||||
input('Press <enter> when you are ready to continue:')
|
input('Press <enter> when you are ready to continue:')
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder
|
||||||
access_token = HfFolder.get_token()
|
access_token = HfFolder.get_token()
|
||||||
if access_token is None:
|
if access_token is None:
|
||||||
print('''
|
print('''
|
||||||
@ -273,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
|||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||||
|
from huggingface_hub import hf_hub_url
|
||||||
|
|
||||||
model_dest = os.path.join(Model_dir, model_name)
|
model_dest = os.path.join(Model_dir, model_name)
|
||||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||||
@ -320,8 +322,10 @@ def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def update_config_file(successfully_downloaded:dict):
|
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||||
yaml = new_config_file_contents(successfully_downloaded)
|
Config_file = opt.config_file or Default_config_file
|
||||||
|
|
||||||
|
yaml = new_config_file_contents(successfully_downloaded,Config_file)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.exists(Config_file):
|
if os.path.exists(Config_file):
|
||||||
@ -341,7 +345,7 @@ def update_config_file(successfully_downloaded:dict):
|
|||||||
|
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def new_config_file_contents(successfully_downloaded:dict)->str:
|
def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str:
|
||||||
if os.path.exists(Config_file):
|
if os.path.exists(Config_file):
|
||||||
conf = OmegaConf.load(Config_file)
|
conf = OmegaConf.load(Config_file)
|
||||||
else:
|
else:
|
||||||
@ -379,6 +383,7 @@ def new_config_file_contents(successfully_downloaded:dict)->str:
|
|||||||
# this will preload the Bert tokenizer fles
|
# this will preload the Bert tokenizer fles
|
||||||
def download_bert():
|
def download_bert():
|
||||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||||
|
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||||
@ -396,9 +401,10 @@ def download_kornia():
|
|||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def download_clip():
|
def download_clip():
|
||||||
version = 'openai/clip-vit-large-patch14'
|
|
||||||
sys.stdout.flush()
|
|
||||||
print('Loading CLIP model...',end='')
|
print('Loading CLIP model...',end='')
|
||||||
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
sys.stdout.flush()
|
||||||
|
version = 'openai/clip-vit-large-patch14'
|
||||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||||
transformer = CLIPTextModel.from_pretrained(version)
|
transformer = CLIPTextModel.from_pretrained(version)
|
||||||
print('...success')
|
print('...success')
|
||||||
@ -498,6 +504,7 @@ def download_safety_checker():
|
|||||||
print('Installing safety model for NSFW content detection...',end='')
|
print('Installing safety model for NSFW content detection...',end='')
|
||||||
try:
|
try:
|
||||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||||
|
from transformers import AutoFeatureExtractor
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
print('Error installing safety checker model:')
|
print('Error installing safety checker model:')
|
||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
@ -515,7 +522,14 @@ if __name__ == '__main__':
|
|||||||
action=argparse.BooleanOptionalAction,
|
action=argparse.BooleanOptionalAction,
|
||||||
default=True,
|
default=True,
|
||||||
help='run in interactive mode (default)')
|
help='run in interactive mode (default)')
|
||||||
|
parser.add_argument('--config_file',
|
||||||
|
'-c',
|
||||||
|
dest='config_file',
|
||||||
|
type=str,
|
||||||
|
default='./configs/models.yaml',
|
||||||
|
help='path to configuration file to create')
|
||||||
opt = parser.parse_args()
|
opt = parser.parse_args()
|
||||||
|
load_libs()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if opt.interactive:
|
if opt.interactive:
|
||||||
@ -531,7 +545,7 @@ if __name__ == '__main__':
|
|||||||
access_token = authenticate()
|
access_token = authenticate()
|
||||||
print('\n** DOWNLOADING WEIGHTS **')
|
print('\n** DOWNLOADING WEIGHTS **')
|
||||||
successfully_downloaded = download_weight_datasets(models, access_token)
|
successfully_downloaded = download_weight_datasets(models, access_token)
|
||||||
update_config_file(successfully_downloaded)
|
update_config_file(successfully_downloaded,opt)
|
||||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||||
download_bert()
|
download_bert()
|
||||||
download_kornia()
|
download_kornia()
|
||||||
|
Loading…
Reference in New Issue
Block a user