mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ef68a419f1
- user can select which weight files to download using huggingface cache - user must log in to huggingface, generate an access token, and accept license terms the very first time this is run. After that, everything works automatically. - added placeholder for docs for installing models - also got rid of unused config files. hopefully they weren't needed for textual inversion, but I don't think so.
412 lines
16 KiB
Python
412 lines
16 KiB
Python
#!/usr/bin/env python3
|
|
# Copyright (c) 2022 Lincoln D. Stein (https://github.com/lstein)
|
|
# Before running stable-diffusion on an internet-isolated machine,
|
|
# run this script from one with internet connectivity. The
|
|
# two machines must share a common .cache directory.
|
|
#
|
|
# Coauthor: Kevin Turner http://github.com/keturn
|
|
#
|
|
print('Loading Python libraries...\n')
|
|
import clip
|
|
import sys
|
|
import transformers
|
|
import os
|
|
import warnings
|
|
import torch
|
|
import urllib.request
|
|
import zipfile
|
|
import traceback
|
|
import getpass
|
|
from omegaconf import OmegaConf
|
|
from pathlib import Path
|
|
from transformers import CLIPTokenizer, CLIPTextModel
|
|
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
|
from huggingface_hub import hf_hub_download, HfFolder, hf_hub_url
|
|
|
|
transformers.logging.set_verbosity_error()
|
|
|
|
#--------------------------globals--
|
|
# Directory that receives the downloaded weight files.
Model_dir = './models/ldm/stable-diffusion-v1/'
# Model registry file; it is loaded and then rewritten by update_config_file().
Config_file = './configs/models.yaml'
# Directory holding the per-model inference config files named in Datasets.
SD_Configs = './configs/stable-diffusion'

# Catalogue of downloadable weight files, keyed by the model name that is
# also used as the stanza name in Config_file.  Per-entry fields:
#   description - text shown in the selection menu and copied to the config
#   repo_id     - Hugging Face repository passed to hf_hub_download()
#   config      - file name joined onto SD_Configs, or the marker 'VAE' for
#                 an autoencoder that is attached to the other models
#   file        - file name inside the repository and inside Model_dir
#   recommended - default answer offered for the "Download?" question
#   width/height- copied verbatim into the model's config stanza
Datasets = {
    'stable-diffusion-1.5': {
        'description': 'The newest Stable Diffusion version 1.5 weight file (4.27 GB)',
        'repo_id': 'runwayml/stable-diffusion-v1-5',
        'config': 'v1-inference.yaml',
        'file': 'v1-5-pruned-emaonly.ckpt',
        'recommended': True,
        'width': 512,
        'height': 512,
    },
    'inpainting-1.5': {
        'description': 'RunwayML SD 1.5 model optimized for inpainting (4.27 GB)',
        'repo_id': 'runwayml/stable-diffusion-inpainting',
        'config': 'v1-inpainting-inference.yaml',
        'file': 'sd-v1-5-inpainting.ckpt',
        'recommended': True,
        'width': 512,
        'height': 512,
    },
    'stable-diffusion-1.4': {
        'description': 'The original Stable Diffusion version 1.4 weight file (4.27 GB)',
        'repo_id': 'CompVis/stable-diffusion-v-1-4-original',
        'config': 'v1-inference.yaml',
        'file': 'sd-v1-4.ckpt',
        'recommended': False,
        'width': 512,
        'height': 512,
    },
    'waifu-diffusion-1.3': {
        # The size now carries its unit, matching the sibling descriptions.
        'description': 'Stable Diffusion 1.4 fine tuned on anime-styled images (4.27 GB)',
        'repo_id': 'hakurei/waifu-diffusion-v1-3',
        'config': 'v1-inference.yaml',
        'file': 'model-epoch09-float32.ckpt',
        'recommended': False,
        'width': 512,
        'height': 512,
    },
    'ft-mse-improved-autoencoder-840000': {
        'description': 'StabilityAI improved autoencoder fine-tuned for human faces (recommended; 335 MB)',
        'repo_id': 'stabilityai/sd-vae-ft-mse-original',
        'config': 'VAE',
        'file': 'vae-ft-mse-840000-ema-pruned.ckpt',
        'recommended': True,
        'width': 512,
        'height': 512,
    },
}

# Comment header written at the top of Config_file ahead of the YAML body.
Config_preamble = (
    '# This file describes the alternative machine learning models\n'
    '# available to InvokeAI script.\n'
    '#\n'
    '# To add a new model, follow the examples below. Each\n'
    '# model requires a model config file, a weights file,\n'
    '# and the width and height of the images it\n'
    '# was trained on.\n'
)
|
|
|
|
#---------------------------------------------
|
|
def introduction():
    """Print the welcome banner describing what this script downloads."""
    banner = (
        'Welcome to InvokeAI. This script will help download the Stable Diffusion weight files\n'
        'and other large models that are needed for text to image generation. At any point you may interrupt\n'
        'this program and resume later.\n'
    )
    print(banner)
|
|
|
|
#---------------------------------------------
|
|
def yes_or_no(prompt:str, default_yes=True):
    """Ask a yes/no question on the console and return the answer as a bool.

    The prompt is shown followed by the default answer in brackets.  An
    empty or whitespace-only reply selects the default.  Otherwise only the
    first non-blank character is examined: when the default is yes, anything
    except 'n'/'N' counts as yes; when the default is no, only 'y'/'Y'
    counts as yes.

    Args:
        prompt: Question text displayed to the user.
        default_yes: Answer assumed when the user just presses <enter>.

    Returns:
        True for "yes", False for "no".
    """
    default = 'y' if default_yes else 'n'
    # Strip surrounding blanks so that a reply such as " n" is judged by its
    # first real character instead of by the leading space.
    response = input(f'{prompt} [{default}] ').strip() or default
    first = response[0]
    if default_yes:
        return first not in ('n', 'N')
    return first in ('y', 'Y')
|
|
|
|
#---------------------------------------------
|
|
def user_wants_to_download_weights():
    """Ask whether the Stable Diffusion weights should be downloaded now.

    Returns the bool produced by yes_or_no(), whose default answer is yes.
    """
    question = 'Would you like to download the Stable Diffusion model weights now?'
    answer = yes_or_no(question)
    return answer
|
|
|
|
#---------------------------------------------
|
|
def select_datasets():
    """Interactively choose which entries of Datasets to download.

    Every catalogue entry is listed with its menu number and description and
    the user is asked whether to download it (recommended entries default to
    yes).  A summary is then shown with the first selected entry starred as
    the default, and the user confirms, starts over, or gives up.

    Returns:
        A dict mapping each selected dataset name to its menu number, or
        None when the user declines the download and does not want to
        change the selection.
    """
    instructions = '''
Choose the weight file(s) you wish to download. Before downloading you
will be given the option to view and change your selections.
'''
    while True:
        print(instructions)
        chosen = {}

        # NOTE(review): the menu number advances for every listed entry,
        # whether or not it is selected -- confirm against the upstream file,
        # since the indentation of this loop was lost in this copy.
        for number, name in enumerate(Datasets, start=1):
            entry = Datasets[name]
            tag = '(recommended)' if entry['recommended'] else ''
            print(f'[{number}] {name}:\n {entry["description"]} {tag}')
            if yes_or_no(' Download?', default_yes=entry['recommended']):
                chosen[name] = number

        print('The following weight files will be downloaded:')
        for position, name in enumerate(chosen):
            # the first model selected will be the default; TODO let user change
            star = '*' if position == 0 else ''
            print(f' [{chosen[name]}] {name}{star}')
        print("*default")

        if yes_or_no('Ok to download?'):
            return chosen
        if not yes_or_no('Change your selection?'):
            return None
|
|
|
|
#-------------------------------Authenticate against Hugging Face
|
|
def authenticate():
    """Walk the user through the license step and return a Hugging Face token.

    Prints the license instructions and waits for <enter>.  If HfFolder
    already holds a token it is returned as is; otherwise the user is asked
    to paste one (read without echo via getpass), which is stored with
    HfFolder.save_token() before being returned.
    """
    license_notice = '''
To download the Stable Diffusion weight files you need to read and accept the
CreativeML Responsible AI license. If you have not already done so, please
create an account at https://huggingface.co. Then login under your account and
read and accept the license available at https://huggingface.co/CompVis/stable-diffusion-v-1-4-original.
'''
    token_request = '''
Thank you! Now you need to authenticate with your HuggingFace access token.
Go to https://huggingface.co/settings/tokens and create a token. Copy it to the
clipboard and paste it here: '''

    print(license_notice)
    input('Press <enter> when you are ready to continue:')

    token = HfFolder.get_token()
    if token is not None:
        return token

    # NOTE(review): the token is saved only when it was newly entered --
    # confirm against upstream, indentation was lost in this copy.
    print(token_request)
    token = getpass.getpass()
    HfFolder.save_token(token)
    return token
|
|
|
|
#---------------------------------------------
|
|
# look for legacy model.ckpt in models directory and offer to
|
|
# normalize its name
|
|
def migrate_models_ckpt():
    """Offer to rename a legacy "model.ckpt" to the catalogue's v1.4 file name.

    Does nothing when Model_dir contains no "model.ckpt".  Otherwise asks
    the user for permission and, if granted, renames the file to the 'file'
    name recorded for 'stable-diffusion-1.4' in Datasets.
    """
    legacy_path = os.path.join(Model_dir, 'model.ckpt')
    if not os.path.exists(legacy_path):
        return

    new_name = Datasets['stable-diffusion-1.4']['file']
    # The message previously said "v4.1"; the file is the v1.4 checkpoint.
    print('You seem to have the Stable Diffusion v1.4 "model.ckpt" already installed.')
    if yes_or_no(f'Ok to rename it to "{new_name}" for future reference?'):
        print(f'model.ckpt => {new_name}')
        os.rename(legacy_path, os.path.join(Model_dir, new_name))
|
|
|
|
#---------------------------------------------
|
|
def download_weight_datasets(models:dict, access_token:str):
    """Download every selected weight file and report which ones succeeded.

    First gives the user the chance to migrate a legacy "model.ckpt", then
    calls conditional_download() for each selected model using the repo id
    and file name recorded in Datasets.

    Args:
        models: Mapping whose keys are names from Datasets.  None or an
            empty mapping is treated as "nothing selected".
        access_token: Hugging Face token forwarded to conditional_download().

    Returns:
        A dict mapping the name of each successfully installed model to True.
    """
    migrate_models_ckpt()
    successful = dict()
    # `models` may be None when the user declined the selection; iterate
    # over an empty mapping instead of failing on None.
    for mod in (models or {}):
        entry = Datasets[mod]
        installed = conditional_download(
            repo_id=entry['repo_id'],
            model_name=entry['file'],
            access_token=access_token,
        )
        if installed:
            successful[mod] = True

    # Only claim success when at least one file was actually installed.
    if successful:
        keys = ', '.join(successful)
        print(f'Successfully installed {keys}')
    return successful
|
|
|
|
#---------------------------------------------
|
|
def conditional_download(repo_id:str, model_name:str, access_token:str):
    """Make one weight file available in Model_dir unless it already exists.

    The file is obtained through hf_hub_download() and then hard-linked from
    the path that call returns into Model_dir.

    Args:
        repo_id: Hugging Face repository that hosts the file.
        model_name: File name inside the repository; also the name used
            inside Model_dir.
        access_token: Hugging Face token passed as use_auth_token.

    Returns:
        True if the file already existed or was linked into place, False if
        any step raised (the error is printed, not propagated).
    """
    model_dest = os.path.join(Model_dir, model_name)
    if os.path.exists(model_dest):
        print(f' * {model_name}: exists')
        return True
    os.makedirs(os.path.dirname(model_dest), exist_ok=True)

    try:
        print(f' * {model_name}: downloading or retrieving from cache...')
        cached = Path(hf_hub_download(repo_id, model_name, use_auth_token=access_token))
        # Path.link_to() was deprecated in Python 3.10 and removed in 3.12;
        # os.link(src, dst) is the call it wrapped.  resolve(strict=True)
        # yields the real file path and raises if it is missing.
        os.link(cached.resolve(strict=True), model_dest)
    except Exception as e:
        print(f'** Error downloading {model_name}: {str(e)} **')
        return False
    return True
|
|
|
|
#---------------------------------------------
|
|
def update_config_file(successfully_downloaded:dict):
    """Rewrite Config_file so that it includes the newly installed models.

    The new contents (Config_preamble followed by the YAML produced by
    new_config_file_contents()) are written to a temporary file next to
    Config_file, which is then moved over the original.  Any failure is
    reported on stdout and swallowed.

    Args:
        successfully_downloaded: Mapping whose keys are the installed model
            names, as returned by download_weight_datasets().
    """
    try:
        yaml = new_config_file_contents(successfully_downloaded)
        tmpfile = os.path.join(os.path.dirname(Config_file),'new_config.tmp')
        with open(tmpfile, 'w') as outfile:
            outfile.write(Config_preamble)
            outfile.write(yaml)
        # os.replace() overwrites an existing destination on every platform;
        # os.rename() raises FileExistsError on Windows, and Config_file
        # always exists here because it was just loaded.
        os.replace(tmpfile,Config_file)
    except Exception as e:
        print(f'**Error creating config file {Config_file}: {str(e)} **')
        return
    print(f'Successfully created new configuration file {Config_file}')
|
|
|
|
|
|
#---------------------------------------------
|
|
def new_config_file_contents(successfully_downloaded:dict)->str:
    """Return the YAML text for Config_file after merging in new models.

    The current Config_file is loaded with OmegaConf.  For every downloaded
    model that is not a VAE, its stanza is created or updated from the
    matching Datasets entry.  If a VAE was downloaded, its path is attached
    to each of those stanzas.

    Args:
        successfully_downloaded: Mapping whose keys are names from Datasets.

    Returns:
        The merged configuration serialized by OmegaConf.to_yaml().
    """
    conf = OmegaConf.load(Config_file)

    # find the VAE file, if there is one (the last VAE entry wins)
    vae_files = [
        Datasets[name]['file']
        for name in successfully_downloaded
        if Datasets[name]['config'] == 'VAE'
    ]
    vae = vae_files[-1] if vae_files else None

    # VAE entries get no stanza of their own
    checkpoints = [
        name
        for name in successfully_downloaded
        if Datasets[name]['config'] != 'VAE'
    ]

    for position, name in enumerate(checkpoints):
        entry = Datasets[name]
        stanza = conf[name] if name in conf else { }

        stanza['description'] = entry['description']
        stanza['weights'] = os.path.join(Model_dir, entry['file'])
        stanza['config'] = os.path.join(SD_Configs, entry['config'])
        stanza['width'] = entry['width']
        stanza['height'] = entry['height']
        stanza.pop('default', None)  # re-assigned just below
        if vae:
            stanza['vae'] = os.path.join(Model_dir, vae)
        # BUG - the first stanza is always the default. User should select.
        if position == 0:
            stanza['default'] = True
        conf[name] = stanza

    return OmegaConf.to_yaml(conf)
|
|
|
|
#---------------------------------------------
|
|
# this will preload the Bert tokenizer files
|
|
def download_bert():
    """Load the 'bert-base-uncased' fast tokenizer with deprecation warnings hidden.

    The returned tokenizer is discarded; the call is made only for its
    effect.  Presumably that effect is populating the local cache -- the
    fetch itself happens inside transformers and is not visible here.
    """
    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        BertTokenizerFast.from_pretrained('bert-base-uncased')
        print('...success')
    sys.stdout.flush()
|
|
|
|
#---------------------------------------------
|
|
# this will download requirements for Kornia
|
|
def download_kornia():
    """Import kornia once, with deprecation warnings silenced.

    The import is the whole point of the function: per the comment above
    it, importing is expected to download Kornia's requirements.
    """
    print('Installing Kornia requirements...', end='')
    with warnings.catch_warnings():
        warnings.filterwarnings('ignore', category=DeprecationWarning)
        import kornia  # imported for its side effects only
    print('...success')
|
|
|
|
#---------------------------------------------
|
|
def download_clip():
    """Load the CLIP tokenizer and text model for 'openai/clip-vit-large-patch14'.

    Both objects are discarded after loading.  Presumably the loads are done
    to populate the local cache -- TODO confirm.
    """
    version = 'openai/clip-vit-large-patch14'
    sys.stdout.flush()
    print('Loading CLIP model...',end='')
    # tokenizer first, then the text transformer
    for loader in (CLIPTokenizer, CLIPTextModel):
        loader.from_pretrained(version)
    print('...success')
|
|
|
|
#---------------------------------------------
|
|
def download_gfpgan():
    """Install the RealESRGAN/facexlib models and the GFPGAN weight files.

    Step 1 constructs a RealESRGANer and a FaceRestoreHelper.  Presumably
    constructing them makes the libraries fetch their own weights -- TODO
    confirm.  Step 2 downloads three .pth files with urllib unless the
    destination already exists.  Failures in either step are printed with a
    traceback and do not stop the remaining work.
    """
    print('Installing models from RealESRGAN and facexlib...',end='')
    try:
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact
        from facexlib.utils.face_restoration_helper import FaceRestoreHelper

        upscaler_net = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
        RealESRGANer(
            scale=4,
            model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
            model=upscaler_net,
        )
        FaceRestoreHelper(1, det_model='retinaface_resnet50')
        print('...success')
    except Exception:
        print('Error loading ESRGAN:')
        print(traceback.format_exc())

    print('Loading models from GFPGAN')
    # (source URL, destination path) pairs
    gfpgan_files = (
        (
            'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
            'src/gfpgan/experiments/pretrained_models/GFPGANv1.4.pth',
        ),
        (
            'https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth',
            './gfpgan/weights/detection_Resnet50_Final.pth',
        ),
        (
            'https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth',
            './gfpgan/weights/parsing_parsenet.pth',
        ),
    )
    for model_url, model_dest in gfpgan_files:
        try:
            if os.path.exists(model_dest):
                continue
            print(f'Downloading gfpgan model file {model_url}...',end='')
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url, model_dest)
            print('...success')
        except Exception:
            print('Error loading GFPGAN:')
            print(traceback.format_exc())
|
|
|
|
#---------------------------------------------
|
|
def download_codeformer():
    """Download the CodeFormer weight file unless it is already present.

    The file is fetched with urllib into
    'ldm/invoke/restoration/codeformer/weights/codeformer.pth', creating the
    directory if needed.  Errors are printed with a traceback and swallowed;
    '...success' is printed only when no error occurred.
    """
    print('Installing CodeFormer model file...',end='')
    model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
    model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
    try:
        if not os.path.exists(model_dest):
            print('Downloading codeformer model file...')
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url,model_dest)
    except Exception:
        print('Error loading CodeFormer:')
        print(traceback.format_exc())
    else:
        # Report success only on the error-free path; it used to be printed
        # even after a failure.
        print('...success')
|
|
|
|
#---------------------------------------------
|
|
def download_clipseg():
    """Install the clipseg weights used for text-based masking.

    When 'src/clipseg/weights' does not exist, the weights archive is
    downloaded, extracted into 'src/clipseg', the extracted
    'clipseg_weights' directory is renamed to 'weights', the archive is
    removed, and the refined checkpoint is loaded into a CLIPDensePredT
    model on the CPU.  Errors are printed with a traceback and swallowed;
    '...success' is printed only when no error occurred.
    """
    print('Installing clipseg model for text-based masking...',end='')
    model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
    model_dest = 'src/clipseg/clipseg_weights.zip'
    weights_dir = 'src/clipseg/weights'
    try:
        if not os.path.exists(weights_dir):
            os.makedirs(os.path.dirname(model_dest), exist_ok=True)
            urllib.request.urlretrieve(model_url,model_dest)
            # `archive` rather than `zip`, to avoid shadowing the builtin
            with zipfile.ZipFile(model_dest,'r') as archive:
                archive.extractall('src/clipseg')
            os.rename('src/clipseg/clipseg_weights','src/clipseg/weights')
            os.remove(model_dest)

            # NOTE(review): the trial load is done only right after a fresh
            # install; the original indentation was lost in this copy, so
            # confirm whether upstream runs it on every invocation instead.
            from clipseg_models.clipseg import CLIPDensePredT
            model = CLIPDensePredT(version='ViT-B/16', reduce_dim=64, )
            model.eval()
            model.load_state_dict(
                torch.load(
                    'src/clipseg/weights/rd64-uni-refined.pth',
                    map_location=torch.device('cpu')
                ),
                strict=False,
            )
    except Exception:
        print('Error installing clipseg model:')
        print(traceback.format_exc())
    else:
        # Report success only on the error-free path; it used to be printed
        # even after a failure.
        print('...success')
|
|
|
|
#-------------------------------------
|
|
def download_safety_checker():
    """Load the NSFW safety-checker feature extractor and model.

    If the diffusers safety-checker module cannot be imported, the traceback
    is printed and the function returns without loading anything.  Otherwise
    both objects are loaded from 'CompVis/stable-diffusion-safety-checker'
    and discarded.  Presumably the loads populate the local cache -- TODO
    confirm.
    """
    print('Installing safety model for NSFW content detection...',end='')
    try:
        from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
    except ModuleNotFoundError:
        print('Error installing safety checker model:')
        print(traceback.format_exc())
        return

    safety_model_id = "CompVis/stable-diffusion-safety-checker"
    # feature extractor first, then the checker itself
    for loader in (AutoFeatureExtractor, StableDiffusionSafetyChecker):
        loader.from_pretrained(safety_model_id)
    print('...success')
|
|
|
|
#-------------------------------------
|
|
# Script entry point: optionally download the Stable Diffusion weights, then
# always install the support models.
if __name__ == '__main__':
    introduction()
    if user_wants_to_download_weights():
        models = select_datasets()
        if models is None:
            # The user backed out of the selection: either quit, or skip the
            # weights and carry on with the support models below.
            if yes_or_no('Quit?',default_yes=False):
                sys.exit(0)
        else:
            # Authenticate and download only when there is a selection;
            # previously a None selection was passed on and crashed.
            access_token = authenticate()
            successfully_downloaded = download_weight_datasets(models, access_token)
            update_config_file(successfully_downloaded)
    download_bert()
    download_kornia()
    download_clip()
    download_gfpgan()
    download_codeformer()
    download_clipseg()
    download_safety_checker()
|
|
|
|
|