configure: try to download models even without token

Models in the CompVis and stabilityai repos no longer require an access token. (But runwayml still does.)
Kevin Turner 2022-11-26 10:58:12 -08:00
parent c21660a6df
commit 8ce1ae550b
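
The commit message turns on a Hub behavior worth spelling out: gated model repos answer unauthenticated downloads with HTTP 401/403, while ungated repos serve files to anyone. A minimal sketch of how one might check this (not part of the commit; repo_is_public and the probe filename are illustrative, reusing the requests and hf_hub_url imports the diff below adds):

import requests
from huggingface_hub import hf_hub_url

def repo_is_public(repo_id: str, filename: str = 'model_index.json') -> bool:
    # Gated repos return 401/403 to unauthenticated requests;
    # ungated repos return 200 (or a redirect to the file's storage).
    response = requests.head(hf_hub_url(repo_id, filename), allow_redirects=True)
    return response.ok

Per the message above, repo_is_public('CompVis/stable-diffusion-v1-4') should now hold, while the runwayml repos still need a token.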


@@ -8,26 +8,28 @@
 #
 print('Loading Python libraries...\n')
 import argparse
-import sys
 import os
 import re
-import warnings
 import shutil
-from urllib import request
-from tqdm import tqdm
-from omegaconf import OmegaConf
-from huggingface_hub import HfFolder, hf_hub_url
+import sys
+import traceback
+import warnings
 from pathlib import Path
+from typing import Dict
+from urllib import request
+import requests
+import transformers
+from diffusers import StableDiffusionPipeline, AutoencoderKL
 from getpass_asterisk import getpass_asterisk
+from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami
+from omegaconf import OmegaConf
+from tqdm import tqdm
 from transformers import CLIPTokenizer, CLIPTextModel
 from ldm.invoke.globals import Globals
 from ldm.invoke.readline import generic_completer
-import traceback
-import requests
-import clip
-import transformers
-import warnings
 warnings.filterwarnings('ignore')
 import torch
 transformers.logging.set_verbosity_error()
@@ -332,6 +334,58 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
         print(traceback.format_exc())
+#---------------------------------------------
+def download_diffusers(models: Dict, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    diffusers_repos = {
+        'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4',
+        'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5',
+        'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting',
+        'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion'
+    }
+    vae_repos = {
+        'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse',
+    }
+    precision_args = {}
+    if not full_precision:
+        precision_args.update(revision='fp16')
+
+    for model_name, model in models.items():
+        repo_id = model['repo_id']
+        if repo_id in vae_repos:
+            print(f" * Downloading diffusers VAE {model_name}...")
+            # TODO: can we autodetect when a repo has no fp16 revision?
+            AutoencoderKL.from_pretrained(vae_repos[repo_id])
+        elif repo_id in diffusers_repos:
+            print(f" * Downloading diffusers {model_name}...")
+            StableDiffusionPipeline.from_pretrained(diffusers_repos[repo_id], **precision_args)
+        else:
+            warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}")
+            continue
+
+def download_diffusers_in_config(config_path: Path, full_precision: bool):
+    # This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
+    # which moves a bunch of stuff.
+    # We can be more complete after we know it won't be all merge conflicts.
+    if not is_huggingface_authenticated():
+        print("*⚠ No Hugging Face access token; some downloads may be blocked.")
+
+    precision = 'full' if full_precision else 'float16'
+    cache = ModelCache(OmegaConf.load(config_path), precision=precision,
+                       device_type='cpu', max_loaded_models=1)
+    for model_name in cache.list_models():
+        # TODO: download model without loading it.
+        # https://github.com/huggingface/diffusers/issues/1301
+        model_config = cache.config[model_name]
+        if model_config.get('format') == 'diffusers':
+            print(f" * Downloading diffusers {model_name}...")
+            cache.get_model(model_name)
+            cache.offload_model(model_name)
 #---------------------------------------------
 def update_config_file(successfully_downloaded:dict,opt:dict):
     config_file = opt.config_file or Default_config_file
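
download_diffusers_in_config calls is_huggingface_authenticated(), which this diff doesn't show. Given the whoami as hf_whoami import added above, a plausible sketch of such a helper (an assumption, not the commit's actual implementation):

from huggingface_hub import HfFolder, whoami as hf_whoami

def is_huggingface_authenticated() -> bool:
    # A locally cached token is not proof of access by itself:
    # whoami() raises (e.g. requests.HTTPError) if it is missing or invalid.
    token = HfFolder.get_token()
    if token is None:
        return False
    try:
        hf_whoami(token)
        return True
    except Exception:
        return False
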
@@ -404,7 +458,7 @@ def download_bert():
     print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
-        from transformers import BertTokenizerFast, AutoFeatureExtractor
+        from transformers import BertTokenizerFast
         download_from_hf(BertTokenizerFast,'bert-base-uncased')
     print('...success',file=sys.stderr)
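
Returning to the TODO in download_diffusers above: an fp16 revision is just a git branch on the Hub repo, so huggingface_hub (>= 0.10) can list it before it is requested. One possible answer, sketched here (has_fp16_revision is illustrative):

from huggingface_hub import HfApi

def has_fp16_revision(repo_id: str) -> bool:
    # Hub revisions are git refs; fp16 weights live on a branch named 'fp16'.
    refs = HfApi().list_repo_refs(repo_id)
    return any(branch.name == 'fp16' for branch in refs.branches)

download_diffusers could then add revision='fp16' to precision_args only when the branch exists, instead of assuming every repo publishes one.
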
@@ -727,6 +781,12 @@ def main():
         if opt.interactive:
             print('** DOWNLOADING DIFFUSION WEIGHTS **')
             download_weights(opt)
+        else:
+            config_path = Path(opt.config_file or Default_config_file)
+            if config_path.exists():
+                download_diffusers_in_config(config_path, full_precision=opt.full_precision)
+            else:
+                print("*⚠ No config file found; downloading no weights.")
         print('\n** DOWNLOADING SUPPORT MODELS **')
         download_bert()
         download_clip()
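
For reference, the new non-interactive path only acts on config entries whose format is 'diffusers' and reads their repo_id. A hypothetical minimal models file and invocation (the entry name, path, and any further fields ModelCache may require are illustrative):

from pathlib import Path

example_config = '''
stable-diffusion-1.5:
  format: diffusers
  repo_id: runwayml/stable-diffusion-v1-5
'''

config_path = Path('models.yaml')  # illustrative location
config_path.write_text(example_config)
download_diffusers_in_config(config_path, full_precision=False)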