Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
configure: try to download models even without token
Models in the CompVis and stabilityai repos no longer require a token. (But runwayml still does.)
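
For context, the token-optional behavior rests on the pattern hf_download_with_resume() uses in the diff below: attach an Authorization header only when a token is stored, so public repos (CompVis, stabilityai) download anonymously while gated ones (runwayml) still receive the credential. A minimal sketch of that pattern (illustrative only; the fetch() helper and the commented example file are assumptions, not part of this commit):

    import requests
    from huggingface_hub import HfFolder, hf_hub_url

    def fetch(repo_id: str, filename: str) -> bytes:
        # Hypothetical helper showing the token-optional download pattern.
        url = hf_hub_url(repo_id, filename)
        token = HfFolder.get_token()  # None when the user never logged in
        headers = {'Authorization': f'Bearer {token}'} if token else {}
        resp = requests.get(url, headers=headers)
        resp.raise_for_status()  # gated repos reject anonymous requests with 401/403
        return resp.content

    # e.g. fetch('CompVis/stable-diffusion-v1-4-original', 'sd-v1-4.ckpt')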
commit 8ce1ae550b (parent c21660a6df)
@@ -8,26 +8,28 @@
 #
 print('Loading Python libraries...\n')
 import argparse
-import sys
 import os
 import re
-import warnings
 import shutil
-from urllib import request
-from tqdm import tqdm
-from omegaconf import OmegaConf
-from huggingface_hub import HfFolder, hf_hub_url
+import sys
+import traceback
+import warnings
 from pathlib import Path
+from typing import Dict
+from urllib import request
+
+import requests
+import transformers
+from diffusers import StableDiffusionPipeline, AutoencoderKL
 from getpass_asterisk import getpass_asterisk
+from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami
+from omegaconf import OmegaConf
+from tqdm import tqdm
 from transformers import CLIPTokenizer, CLIPTextModel
+
 from ldm.invoke.globals import Globals
 from ldm.invoke.readline import generic_completer
 
-import traceback
-import requests
-import clip
-import transformers
-import warnings
 warnings.filterwarnings('ignore')
 import torch
 transformers.logging.set_verbosity_error()
@@ -65,7 +67,7 @@ this program and resume later.\n'''
 def postscript():
     print(
 '''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
 Web version:
     python scripts/invoke.py --web (connect to http://localhost:9090)
 Command-line version:
     python scripts/invoke.py
@@ -127,7 +129,7 @@ def select_datasets(action:str):
 
     if action == 'customized':
         print('''
 Choose the weight file(s) you wish to download. Before downloading you
 will be given the option to view and change your selections.
 '''
         )
@@ -142,7 +144,7 @@ will be given the option to view and change your selections.
         if Datasets[ds]['recommended']:
             datasets[ds]=counter
             counter += 1
 
     print('The following weight files will be downloaded:')
     for ds in datasets:
         dflt = '*' if dflt is None else ''
@@ -166,11 +168,11 @@ def recommended_datasets()->dict:
         if Datasets[ds]['recommended']:
             datasets[ds]=True
     return datasets
 
 #-------------------------------Authenticate against Hugging Face
 def authenticate():
     print('''
 To download the Stable Diffusion weight files from the official Hugging Face
 repository, you need to read and accept the CreativeML Responsible AI license.
 
 This involves a few easy steps.
@@ -203,18 +205,18 @@ This involves a few easy steps.
     access_token = HfFolder.get_token()
     if access_token is not None:
         print('found')
 
     if access_token is None:
         print('not found')
         print('''
 4. Thank you! The last step is to enter your HuggingFace access token so that
 this script is authorized to initiate the download. Go to the access tokens
 page of your Hugging Face account and create a token by clicking the
 "New token" button:
 
 https://huggingface.co/settings/tokens
 
 (You can enter anything you like in the token creation field marked "Name".
 "Role" should be "read").
 
 Now copy the token to your clipboard and paste it here: '''
@@ -235,7 +237,7 @@ def migrate_models_ckpt():
     if rename:
         print(f'model.ckpt => {new_name}')
         os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
 
 #---------------------------------------------
 def download_weight_datasets(models:dict, access_token:str):
     migrate_models_ckpt()
@@ -262,9 +264,9 @@ def download_weight_datasets(models:dict, access_token:str):
 
     HfFolder.save_token(access_token)
     keys = ', '.join(successful.keys())
     print(f'Successfully installed {keys}')
     return successful
 
 #---------------------------------------------
 def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
     model_dest = os.path.join(model_dir, model_name)
@@ -275,7 +277,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
     header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
     open_mode = 'wb'
     exist_size = 0
 
     if os.path.exists(model_dest):
         exist_size = os.path.getsize(model_dest)
         header['Range'] = f'bytes={exist_size}-'
@@ -283,7 +285,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
 
     resp = requests.get(url, headers=header, stream=True)
     total = int(resp.headers.get('content-length', 0))
 
     if resp.status_code==416: # "range not satisfiable", which means nothing to return
         print(f'* {model_name}: complete file found. Skipping.')
         return True
@@ -331,12 +333,64 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
         print(f'Error downloading {label} model')
         print(traceback.format_exc())
 
 
+#---------------------------------------------
+def download_diffusers(models: Dict, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    diffusers_repos = {
+        'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4',
+        'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5',
+        'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting',
+        'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion'
+    }
+    vae_repos = {
+        'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse',
+    }
+    precision_args = {}
+    if not full_precision:
+        precision_args.update(revision='fp16')
+
+    for model_name, model in models.items():
+        repo_id = model['repo_id']
+        if repo_id in vae_repos:
+            print(f" * Downloading diffusers VAE {model_name}...")
+            # TODO: can we autodetect when a repo has no fp16 revision?
+            AutoencoderKL.from_pretrained(repo_id)
+        elif repo_id not in diffusers_repos:
+            print(f" * Downloading diffusers {model_name}...")
+            StableDiffusionPipeline.from_pretrained(repo_id, **precision_args)
+        else:
+            warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}")
+            continue
+
+
+def download_diffusers_in_config(config_path: Path, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    if not is_huggingface_authenticated():
+        print("*⚠ No Hugging Face access token; some downloads may be blocked.")
+
+    precision = 'full' if full_precision else 'float16'
+    cache = ModelCache(OmegaConf.load(config_path), precision=precision,
+                       device_type='cpu', max_loaded_models=1)
+    for model_name in cache.list_models():
+        # TODO: download model without loading it.
+        # https://github.com/huggingface/diffusers/issues/1301
+        model_config = cache.config[model_name]
+        if model_config.get('format') == 'diffusers':
+            print(f" * Downloading diffusers {model_name}...")
+            cache.get_model(model_name)
+            cache.offload_model(model_name)
+
+
 #---------------------------------------------
 def update_config_file(successfully_downloaded:dict,opt:dict):
     config_file = opt.config_file or Default_config_file
     config_file = os.path.normpath(os.path.join(Globals.root,config_file))
 
     yaml = new_config_file_contents(successfully_downloaded,config_file)
 
     try:
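Note that download_diffusers_in_config() calls is_huggingface_authenticated() and ModelCache, neither of which appears in this excerpt (ModelCache presumably comes from ldm.invoke.model_cache elsewhere in the codebase). Given the whoami as hf_whoami import added at the top of the file, a plausible minimal form of the helper, offered as an assumption rather than the commit's actual code, is:

    def is_huggingface_authenticated() -> bool:
        # hf_whoami() queries the Hub with the locally stored token and
        # raises when no valid token is available.
        try:
            hf_whoami()
            return True
        except Exception:
            return False

The TODO about detecting repos with no fp16 revision could be probed through the Hub API; a sketch, assuming huggingface_hub's model_info and RevisionNotFoundError behave as described:

    from huggingface_hub import model_info
    from huggingface_hub.utils import RevisionNotFoundError

    def has_fp16_revision(repo_id: str) -> bool:
        # Ask the Hub whether an 'fp16' branch exists before requesting it.
        try:
            model_info(repo_id, revision='fp16')
            return True
        except RevisionNotFoundError:
            return False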
@@ -355,8 +409,8 @@ def update_config_file(successfully_downloaded:dict,opt:dict):
 
     print(f'Successfully created new configuration file {config_file}')
 
 
 #---------------------------------------------
 def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
     if os.path.exists(config_file):
         conf = OmegaConf.load(config_file)
@@ -366,19 +420,19 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
     # find the VAE file, if there is one
     vaes = {}
     default_selected = False
 
     for model in successfully_downloaded:
         a = Datasets[model]['config'].split('/')
         if a[0] != 'VAE':
             continue
         vae_target = a[1] if len(a)>1 else 'default'
         vaes[vae_target] = Datasets[model]['file']
 
     for model in successfully_downloaded:
         if Datasets[model]['config'].startswith('VAE'): # skip VAE entries
             continue
         stanza = conf[model] if model in conf else { }
 
         stanza['description'] = Datasets[model]['description']
         stanza['weights'] = os.path.join(Model_dir,Weights_dir,Datasets[model]['file'])
         stanza['config'] = os.path.normpath(os.path.join(SD_Configs, Datasets[model]['config']))
@@ -397,14 +451,14 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
             default_selected = True
         conf[model] = stanza
     return OmegaConf.to_yaml(conf)
 
 #---------------------------------------------
 # this will preload the Bert tokenizer fles
 def download_bert():
     print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
-        from transformers import BertTokenizerFast, AutoFeatureExtractor
+        from transformers import BertTokenizerFast
         download_from_hf(BertTokenizerFast,'bert-base-uncased')
         print('...success',file=sys.stderr)
 
@@ -467,7 +521,7 @@ def download_clipseg():
     model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
     model_dest = os.path.join(Globals.root,'models/clipseg/clipseg_weights')
     weights_zip = 'models/clipseg/weights.zip'
 
     if not os.path.exists(model_dest):
         os.makedirs(os.path.dirname(model_dest), exist_ok=True)
     if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
@@ -586,7 +640,7 @@ def select_outputs(root:str,yes_to_all:bool=False):
 #-------------------------------------
 def initialize_rootdir(root:str,yes_to_all:bool=False):
     assert os.path.exists('./configs'),'Run this script from within the InvokeAI source code directory, "InvokeAI" or the runtime directory "invokeai".'
 
     print(f'** INITIALIZING INVOKEAI RUNTIME DIRECTORY **')
     root_selected = False
     while not root_selected:
@@ -670,7 +724,7 @@ def initialize_rootdir(root:str,yes_to_all:bool=False):
 # -Ak_euler_a -C10.0
 #
 ''')
 
 #-------------------------------------
 class ProgressBar():
     def __init__(self,model_name='file'):
@@ -727,6 +781,12 @@ def main():
         if opt.interactive:
             print('** DOWNLOADING DIFFUSION WEIGHTS **')
             download_weights(opt)
+        else:
+            config_path = Path(opt.config_file or Default_config_file)
+            if config_path.exists():
+                download_diffusers_in_config(config_path, full_precision=opt.full_precision)
+            else:
+                print("*⚠ No config file found; downloading no weights.")
         print('\n** DOWNLOADING SUPPORT MODELS **')
         download_bert()
         download_clip()
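In short, a non-interactive run of main() now pre-downloads every model whose config stanza is marked format: diffusers, warning rather than aborting when no Hugging Face token is stored, while the interactive path keeps the original prompt-driven flow. A scripted setup could reach the same code path directly (the config location is the script's usual default, assumed here):

    # Hypothetical scripted use of the new non-interactive path:
    from pathlib import Path
    config_path = Path('configs/models.yaml')  # assumed default location
    if config_path.exists():
        download_diffusers_in_config(config_path, full_precision=False)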
@@ -741,7 +801,7 @@ def main():
     except Exception as e:
         print(f'\nA problem occurred during initialization.\nThe error was: "{str(e)}"')
         print(traceback.format_exc())
 
 #-------------------------------------
 if __name__ == '__main__':
     main()