# InvokeAI/scripts/preload_models.py
#
# 164 lines
# 6.3 KiB
# Python
# (viewer controls: Raw | Normal View | History)
#
# 2022-08-21 23:57:48 +00:00
#!/usr/bin/env python3
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
# Before running stable-diffusion on an internet-isolated machine,
# run this script from one with internet connectivity. The
# two machines must share a common .cache directory.
from transformers import CLIPTokenizer, CLIPTextModel
import clip
from transformers import BertTokenizerFast, AutoFeatureExtractor
import sys
import transformers
import os
import warnings
import torch
import urllib.request
import zipfile
import traceback
# Silence transformers' info/warning output so only errors are shown while preloading.
transformers.logging.set_verbosity(transformers.logging.ERROR)
#---------------------------------------------
# this will preload the Bert tokenizer files
def download_bert():
    """Preload the ``bert-base-uncased`` fast tokenizer files.

    ``BertTokenizerFast.from_pretrained`` is called only for its side effect of
    fetching the tokenizer files; the returned tokenizer is discarded.
    DeprecationWarnings emitted during the load are suppressed.
    """
    # flush=True: with end='' the progress text would otherwise sit in the
    # stdout buffer until the download finishes.
    print('Installing bert tokenizer (ignore deprecation errors)...', end='', flush=True)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        BertTokenizerFast.from_pretrained('bert-base-uncased')
    print('...success')
    sys.stdout.flush()
#---------------------------------------------
# this will download requirements for Kornia
def download_kornia():
    """Import ``kornia`` so that any setup it performs on import happens now.

    NOTE(review): the comment above says this downloads Kornia requirements;
    the code itself only imports the package, so whatever gets fetched is
    whatever kornia does at import time — confirm.
    """
    # flush=True so the message is visible while the import runs.
    print('Installing Kornia requirements...', end='', flush=True)
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        import kornia  # noqa: F401  (imported only for its side effects)
    print('...success')
#---------------------------------------------
def download_clip():
    """Preload the ``openai/clip-vit-large-patch14`` tokenizer and text encoder.

    Both ``from_pretrained`` calls are made only for their side effect of
    fetching the model files; the returned objects are discarded.
    """
    version = 'openai/clip-vit-large-patch14'
    # Flush the message itself (not the stream before printing it): with
    # end='' it would otherwise stay buffered for the whole download.
    print('Loading CLIP model...', end='', flush=True)
    CLIPTokenizer.from_pretrained(version)
    CLIPTextModel.from_pretrained(version)
    print('...success')
#---------------------------------------------
def download_gfpgan():
    """Fetch the Real-ESRGAN, facexlib and GFPGAN weight files.

    Runs in two best-effort phases. Each phase prints a traceback on failure
    instead of raising.

    1. Instantiate ``RealESRGANer`` (pointing at the realesr-general-x4v3
       weights URL) and ``FaceRestoreHelper``. This presumably makes those
       libraries download their own weights.
    2. Download each GFPGAN/facexlib ``.pth`` file listed below to its local
       destination, skipping files that already exist.
    """
    print('Installing models from RealESRGAN and facexlib...', end='', flush=True)
    try:
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        from facexlib.utils.face_restoration_helper import FaceRestoreHelper
        RealESRGANer(
            scale=4,
            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            model=SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu'),
        )
        FaceRestoreHelper(1, det_model='retinaface_resnet50')
        print('...success')
    except Exception:
        print('Error loading ESRGAN:')
        print(traceback.format_exc())

    print('Loading models from GFPGAN')
    # (source URL, local destination) pairs.
    gfpgan_models = (
        (
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
            'src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth',
        ),
        (
            'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth',
            './gfpgan/weights/detection_Resnet50_Final.pth',
        ),
        (
            'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth',
            './gfpgan/weights/parsing_parsenet.pth',
        ),
    )
    for model_url, model_dest in gfpgan_models:
        if os.path.exists(model_dest):
            continue
        # Download to a side file and rename only on success. Otherwise an
        # interrupted download would leave a truncated file at model_dest,
        # which the existence check above would then treat as complete.
        partial_dest = model_dest + '.part'
        try:
            print(f'Downloading gfpgan model file {model_url}...', end='', flush=True)
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url, partial_dest)
            os.replace(partial_dest, model_dest)
            print('...success')
        except Exception:
            print('Error loading GFPGAN:')
            print(traceback.format_exc())
            # Best-effort cleanup of a half-written side file.
            try:
                os.remove(partial_dest)
            except OSError:
                pass
#---------------------------------------------
def download_codeformer(
    model_url='https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth',
    model_dest='ldm/invoke/restoration/codeformer/weights/codeformer.pth',
):
    """Download the CodeFormer weights file unless it is already present.

    Args:
        model_url: URL to fetch the weights from (any scheme
            ``urllib.request.urlretrieve`` accepts).
        model_dest: Local path to store the weights at. Missing parent
            directories are created.

    Failures are printed with a traceback rather than raised. ``'...success'``
    is printed only when the file is present afterwards.
    """
    print('Installing CodeFormer model file...', end='', flush=True)
    # Download to a side file and rename on success, so a failed or
    # interrupted download never leaves a truncated file at model_dest
    # (the existence check would otherwise skip it on every later run).
    partial_dest = model_dest + '.part'
    try:
        if not os.path.exists(model_dest):
            print('Downloading codeformer model file...')
            dest_dir = os.path.dirname(model_dest)
            if dest_dir:  # os.makedirs('') raises for a bare filename
                os.makedirs(dest_dir, exist_ok=True)
            urllib.request.urlretrieve(model_url, partial_dest)
            os.replace(partial_dest, model_dest)
    except Exception:
        print('Error loading CodeFormer:')
        print(traceback.format_exc())
        try:
            os.remove(partial_dest)
        except OSError:
            pass
    else:
        print('...success')
#---------------------------------------------
def download_clipseg():
    """Download, unpack and load-check the clipseg weights used for text masking.

    If ``src/clipseg/weights`` does not exist yet:

    1. The weights zip is downloaded to ``src/clipseg``.
    2. The zip is extracted there.
    3. The extracted ``clipseg_weights`` folder is renamed to ``weights``.
    4. The zip is deleted.

    The ``rd64-uni-refined.pth`` checkpoint is then loaded into a
    ``CLIPDensePredT`` model on CPU to confirm it is usable. Errors are printed
    rather than raised. ``'...success'`` is printed only if every step succeeded.
    """
    print('Installing clipseg model for text-based masking...', end='', flush=True)
    model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
    model_dest = 'src/clipseg/clipseg_weights.zip'
    weights_dir = 'src/clipseg/weights'
    try:
        if not os.path.exists(weights_dir):
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url, model_dest)
            with zipfile.ZipFile(model_dest, 'r') as archive:
                archive.extractall('src/clipseg')
            os.rename('src/clipseg/clipseg_weights', weights_dir)
            os.remove(model_dest)
        from clipseg_models.clipseg import CLIPDensePredT
        model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64)
        model.eval()
        # strict=False tolerates missing/unexpected keys in the checkpoint.
        # NOTE(review): presumably it only stores the non-CLIP weights — confirm.
        model.load_state_dict(
            torch.load(
                'src/clipseg/weights/rd64-uni-refined.pth',
                map_location=torch.device('cpu'),
            ),
            strict=False,
        )
    except Exception:
        print('Error installing clipseg model:')
        print(traceback.format_exc())
    else:
        print('...success')
#-------------------------------------
def download_safety_checker():
    """Preload the ``CompVis/stable-diffusion-safety-checker`` model files.

    If ``diffusers`` (or the safety-checker module inside it) cannot be
    imported, the traceback is printed and the function returns. Errors raised
    while fetching the feature extractor or checker are not caught and
    propagate to the caller.
    """
    # flush=True so the message is visible during the download.
    print('Installing safety model for NSFW content detection...', end='', flush=True)
    try:
        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
    except ModuleNotFoundError:
        print('Error installing safety checker model:')
        print(traceback.format_exc())
        return
    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    # Called only to fetch the files; the returned objects are discarded.
    AutoFeatureExtractor.from_pretrained(safety_model_id)
    StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
    print('...success')
#-------------------------------------
if __name__ == '__main__':
    # Run every preloader in a fixed order; each prints its own progress.
    for preload in (
        download_bert,
        download_kornia,
        download_clip,
        download_gfpgan,
        download_codeformer,
        download_clipseg,
        download_safety_checker,
    ):
        preload()