From 8ce1ae550ba617635b6cedfeda468a9f10977d5b Mon Sep 17 00:00:00 2001
From: Kevin Turner <83819+keturn@users.noreply.github.com>
Date: Sat, 26 Nov 2022 10:58:12 -0800
Subject: [PATCH] configure: try to download models even without token

Models in the CompVis and stabilityai repos no longer require them.
(But runwayml still does.)
---
 scripts/configure_invokeai.py | 134 ++++++++++++++++++++++++----------
 1 file changed, 97 insertions(+), 37 deletions(-)

diff --git a/scripts/configure_invokeai.py b/scripts/configure_invokeai.py
index 2bfefaa28c..7c410e3ead 100644
--- a/scripts/configure_invokeai.py
+++ b/scripts/configure_invokeai.py
@@ -8,26 +8,28 @@
 #
 print('Loading Python libraries...\n')
 import argparse
-import sys
 import os
 import re
-import warnings
 import shutil
-from urllib import request
-from tqdm import tqdm
-from omegaconf import OmegaConf
-from huggingface_hub import HfFolder, hf_hub_url
+import sys
+import traceback
+import warnings
 from pathlib import Path
+from typing import Dict
+from urllib import request
+
+import requests
+import transformers
+from diffusers import StableDiffusionPipeline, AutoencoderKL
 from getpass_asterisk import getpass_asterisk
+from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami
+from omegaconf import OmegaConf
+from tqdm import tqdm
 from transformers import CLIPTokenizer, CLIPTextModel
+
 from ldm.invoke.globals import Globals
 from ldm.invoke.readline import generic_completer
-import traceback
-import requests
-import clip
-import transformers
-import warnings
 
 warnings.filterwarnings('ignore')
 import torch
 transformers.logging.set_verbosity_error()
@@ -65,7 +67,7 @@ this program and resume later.\n'''
 def postscript():
     print(
 '''\n** Model Installation Successful **\nYou're all set! You may now launch InvokeAI using one of these two commands:
-Web version: 
+Web version:
     python scripts/invoke.py --web  (connect to http://localhost:9090)
 Command-line version:
     python scripts/invoke.py
@@ -127,7 +129,7 @@ def select_datasets(action:str):
     if action == 'customized':
         print('''
-Choose the weight file(s) you wish to download. Before downloading you 
+Choose the weight file(s) you wish to download. Before downloading you
 will be given the option to view and change your selections.
 '''
     )
@@ -142,7 +144,7 @@ will be given the option to view and change your selections.
             if Datasets[ds]['recommended']:
                 datasets[ds]=counter
                 counter += 1
-    
+
     print('The following weight files will be downloaded:')
     for ds in datasets:
         dflt = '*' if dflt is None else ''
@@ -166,11 +168,11 @@ def recommended_datasets()->dict:
         if Datasets[ds]['recommended']:
             datasets[ds]=True
     return datasets
-    
+
 #-------------------------------Authenticate against Hugging Face
 def authenticate():
     print('''
-To download the Stable Diffusion weight files from the official Hugging Face 
+To download the Stable Diffusion weight files from the official Hugging Face
 repository, you need to read and accept the CreativeML Responsible AI license.
 This involves a few easy steps.
@@ -203,18 +205,18 @@ This involves a few easy steps.
     access_token = HfFolder.get_token()
     if access_token is not None:
         print('found')
-    
+
     if access_token is None:
         print('not found')
         print('''
 4. Thank you! The last step is to enter your HuggingFace access token so that
    this script is authorized to initiate the download. Go to the access tokens
-   page of your Hugging Face account and create a token by clicking the 
+   page of your Hugging Face account and create a token by clicking the
    "New token" button:
    https://huggingface.co/settings/tokens
-   (You can enter anything you like in the token creation field marked "Name". 
+   (You can enter anything you like in the token creation field marked "Name".
    "Role" should be "read").
 
    Now copy the token to your clipboard and paste it here:
 '''
     )
@@ -235,7 +237,7 @@ def migrate_models_ckpt():
     if rename:
         print(f'model.ckpt => {new_name}')
         os.replace(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
-        
+
 #---------------------------------------------
 def download_weight_datasets(models:dict, access_token:str):
     migrate_models_ckpt()
@@ -262,9 +264,9 @@ def download_weight_datasets(models:dict, access_token:str):
         HfFolder.save_token(access_token)
 
     keys = ', '.join(successful.keys())
-    print(f'Successfully installed {keys}') 
+    print(f'Successfully installed {keys}')
     return successful
-    
+
 #---------------------------------------------
 def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
     model_dest = os.path.join(model_dir, model_name)
@@ -275,7 +277,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
     header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
     open_mode = 'wb'
     exist_size = 0
-    
+
     if os.path.exists(model_dest):
         exist_size = os.path.getsize(model_dest)
         header['Range'] = f'bytes={exist_size}-'
@@ -283,7 +285,7 @@ def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_t
     resp = requests.get(url, headers=header, stream=True)
     total = int(resp.headers.get('content-length', 0))
-    
+
     if resp.status_code==416: # "range not satisfiable", which means nothing to return
         print(f'* {model_name}: complete file found. Skipping.')
         return True
@@ -331,12 +333,64 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
         print(f'Error downloading {label} model')
         print(traceback.format_exc())
 
-    
+
+#---------------------------------------------
+def download_diffusers(models: Dict, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    diffusers_repos = {
+        'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4',
+        'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5',
+        'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting',
+        'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion'
+    }
+    vae_repos = {
+        'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse',
+    }
+    precision_args = {}
+    if not full_precision:
+        precision_args.update(revision='fp16')
+
+    for model_name, model in models.items():
+        repo_id = model['repo_id']
+        if repo_id in vae_repos:
+            print(f" * Downloading diffusers VAE {model_name}...")
+            # TODO: can we autodetect when a repo has no fp16 revision?
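+            # (Editorial sketch of one answer to the TODO above, not a tested API:
+            #  newer huggingface_hub releases can list a repo's refs, so the code
+            #  could pass revision='fp16' only when such a branch actually exists.
+            #  For now the VAE is fetched without precision_args, since
+            #  stabilityai/sd-vae-ft-mse publishes no fp16 revision.)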
+            AutoencoderKL.from_pretrained(vae_repos[repo_id])
+        elif repo_id in diffusers_repos:
+            print(f" * Downloading diffusers {model_name}...")
+            StableDiffusionPipeline.from_pretrained(diffusers_repos[repo_id], **precision_args)
+        else:
+            warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}")
+            continue
+
+
+def download_diffusers_in_config(config_path: Path, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    # Deferred import: ModelCache pulls in heavy model-loading dependencies.
+    from ldm.invoke.model_cache import ModelCache
+
+    if not is_huggingface_authenticated():
+        print("*⚠ No Hugging Face access token; some downloads may be blocked.")
+
+    precision = 'full' if full_precision else 'float16'
+    cache = ModelCache(OmegaConf.load(config_path), precision=precision,
+                       device_type='cpu', max_loaded_models=1)
+    for model_name in cache.list_models():
+        # TODO: download model without loading it.
+        #  https://github.com/huggingface/diffusers/issues/1301
+        model_config = cache.config[model_name]
+        if model_config.get('format') == 'diffusers':
+            print(f" * Downloading diffusers {model_name}...")
+            cache.get_model(model_name)
+            cache.offload_model(model_name)
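+
+
+def is_huggingface_authenticated():
+    # Editorial sketch: the patch imports `whoami as hf_whoami` and calls
+    # is_huggingface_authenticated() above, but no definition appears in this
+    # hunk. Assuming it is a thin wrapper: whoami() succeeds only when a usable
+    # saved token exists and raises otherwise, so catch broadly and return a bool.
+    try:
+        return hf_whoami() is not None
+    except Exception:
+        return False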
+
+
 #---------------------------------------------
 def update_config_file(successfully_downloaded:dict,opt:dict):
     config_file = opt.config_file or Default_config_file
     config_file = os.path.normpath(os.path.join(Globals.root,config_file))
-    
+
     yaml = new_config_file_contents(successfully_downloaded,config_file)
 
     try:
@@ -355,8 +409,8 @@ def update_config_file(successfully_downloaded:dict,opt:dict):
 
     print(f'Successfully created new configuration file {config_file}')
 
-    
-#--------------------------------------------- 
+
+#---------------------------------------------
 def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str:
     if os.path.exists(config_file):
         conf = OmegaConf.load(config_file)
@@ -366,19 +420,19 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str
     # find the VAE file, if there is one
     vaes = {}
     default_selected = False
-    
+
     for model in successfully_downloaded:
         a = Datasets[model]['config'].split('/')
         if a[0] != 'VAE':
             continue
         vae_target = a[1] if len(a)>1 else 'default'
         vaes[vae_target] = Datasets[model]['file']
-    
+
     for model in successfully_downloaded:
         if Datasets[model]['config'].startswith('VAE'): # skip VAE entries
             continue
         stanza = conf[model] if model in conf else { }
-        
+
         stanza['description'] = Datasets[model]['description']
         stanza['weights'] = os.path.join(Model_dir,Weights_dir,Datasets[model]['file'])
         stanza['config'] = os.path.normpath(os.path.join(SD_Configs, Datasets[model]['config']))
@@ -397,14 +451,14 @@ def new_config_file_contents(successfully_downloaded:dict, config_file:str)->str
             default_selected = True
         conf[model] = stanza
     return OmegaConf.to_yaml(conf)
-    
+
 #---------------------------------------------
 # this will preload the Bert tokenizer files
 def download_bert():
     print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
-        from transformers import BertTokenizerFast, AutoFeatureExtractor
+        from transformers import BertTokenizerFast
         download_from_hf(BertTokenizerFast,'bert-base-uncased')
     print('...success',file=sys.stderr)
 
@@ -467,7 +521,7 @@ def download_clipseg():
     model_url  = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
     model_dest = os.path.join(Globals.root,'models/clipseg/clipseg_weights')
     weights_zip = 'models/clipseg/weights.zip'
-    
+
     if not os.path.exists(model_dest):
         os.makedirs(os.path.dirname(model_dest), exist_ok=True)
     if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
@@ -586,7 +640,7 @@ def select_outputs(root:str,yes_to_all:bool=False):
 #-------------------------------------
 def initialize_rootdir(root:str,yes_to_all:bool=False):
     assert os.path.exists('./configs'),'Run this script from within the InvokeAI source code directory, "InvokeAI" or the runtime directory "invokeai".'
-    
+
     print(f'** INITIALIZING INVOKEAI RUNTIME DIRECTORY **')
     root_selected = False
     while not root_selected:
@@ -670,7 +724,7 @@ def initialize_rootdir(root:str,yes_to_all:bool=False):
 #   -Ak_euler_a -C10.0
 #
 ''')
-    
+
 #-------------------------------------
 class ProgressBar():
     def __init__(self,model_name='file'):
@@ -727,6 +781,12 @@ def main():
         if opt.interactive:
             print('** DOWNLOADING DIFFUSION WEIGHTS **')
             download_weights(opt)
+        else:
+            config_path = Path(opt.config_file or Default_config_file)
+            if config_path.exists():
+                download_diffusers_in_config(config_path, full_precision=opt.full_precision)
+            else:
+                print("*⚠ No config file found; downloading no weights.")
         print('\n** DOWNLOADING SUPPORT MODELS **')
         download_bert()
         download_clip()
@@ -741,7 +801,7 @@ def main():
     except Exception as e:
         print(f'\nA problem occurred during initialization.\nThe error was: "{str(e)}"')
         print(traceback.format_exc())
-    
+
 #-------------------------------------
 if __name__ == '__main__':
     main()
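
Note on the commit's premise (editorial, not part of the patch): the claim that the CompVis and stabilityai repos now download without a token, while the runwayml repos stay gated, can be sanity-checked with a short probe before running the full configure script. A sketch, assuming network access and the requests and huggingface_hub packages; the repo ids come from the patch, but the probed filenames are an assumption about each repo's layout:

    # Editorial sketch: probe anonymous access to the repos named in this patch.
    # Expect HTTP 200 from the CompVis/stabilityai repos, and 401/403 from the
    # gated runwayml repo until its license is accepted and a token is supplied.
    import requests
    from huggingface_hub import hf_hub_url

    checks = {
        'CompVis/stable-diffusion-v1-4': 'model_index.json',   # diffusers pipeline repo
        'stabilityai/sd-vae-ft-mse': 'config.json',            # standalone VAE repo
        'runwayml/stable-diffusion-v1-5': 'model_index.json',  # gated repo
    }
    for repo_id, filename in checks.items():
        status = requests.head(hf_hub_url(repo_id, filename),
                               allow_redirects=True).status_code
        print(f'{repo_id}: HTTP {status}')  # 200 = anonymous download OK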