mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
configure: try to download models even without token
Models in the CompVis and stabilityai repos no longer require them. (But runwayml still does.)
This commit is contained in:
parent
c21660a6df
commit
8ce1ae550b
@ -8,26 +8,28 @@
|
|||||||
#
|
#
|
||||||
print('Loading Python libraries...\n')
|
print('Loading Python libraries...\n')
|
||||||
import argparse
|
import argparse
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import warnings
|
|
||||||
import shutil
|
import shutil
|
||||||
from urllib import request
|
import sys
|
||||||
from tqdm import tqdm
|
import traceback
|
||||||
from omegaconf import OmegaConf
|
import warnings
|
||||||
from huggingface_hub import HfFolder, hf_hub_url
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict
|
||||||
|
from urllib import request
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import transformers
|
||||||
|
from diffusers import StableDiffusionPipeline, AutoencoderKL
|
||||||
from getpass_asterisk import getpass_asterisk
|
from getpass_asterisk import getpass_asterisk
|
||||||
|
from huggingface_hub import HfFolder, hf_hub_url, whoami as hf_whoami
|
||||||
|
from omegaconf import OmegaConf
|
||||||
|
from tqdm import tqdm
|
||||||
from transformers import CLIPTokenizer, CLIPTextModel
|
from transformers import CLIPTokenizer, CLIPTextModel
|
||||||
|
|
||||||
from ldm.invoke.globals import Globals
|
from ldm.invoke.globals import Globals
|
||||||
from ldm.invoke.readline import generic_completer
|
from ldm.invoke.readline import generic_completer
|
||||||
|
|
||||||
import traceback
|
|
||||||
import requests
|
|
||||||
import clip
|
|
||||||
import transformers
|
|
||||||
import warnings
|
|
||||||
warnings.filterwarnings('ignore')
|
warnings.filterwarnings('ignore')
|
||||||
import torch
|
import torch
|
||||||
transformers.logging.set_verbosity_error()
|
transformers.logging.set_verbosity_error()
|
||||||
@ -332,6 +334,58 @@ def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
|||||||
print(traceback.format_exc())
|
print(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
#---------------------------------------------
|
||||||
|
def download_diffusers(models: Dict, full_precision: bool):
|
||||||
|
# This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
|
||||||
|
# which moves a bunch of stuff.
|
||||||
|
# We can be more complete after we know it won't be all merge conflicts.
|
||||||
|
diffusers_repos = {
|
||||||
|
'CompVis/stable-diffusion-v1-4-original': 'CompVis/stable-diffusion-v1-4',
|
||||||
|
'runwayml/stable-diffusion-v1-5': 'runwayml/stable-diffusion-v1-5',
|
||||||
|
'runwayml/stable-diffusion-inpainting': 'runwayml/stable-diffusion-inpainting',
|
||||||
|
'hakurei/waifu-diffusion-v1-3': 'hakurei/waifu-diffusion'
|
||||||
|
}
|
||||||
|
vae_repos = {
|
||||||
|
'stabilityai/sd-vae-ft-mse-original': 'stabilityai/sd-vae-ft-mse',
|
||||||
|
}
|
||||||
|
precision_args = {}
|
||||||
|
if not full_precision:
|
||||||
|
precision_args.update(revision='fp16')
|
||||||
|
|
||||||
|
for model_name, model in models.items():
|
||||||
|
repo_id = model['repo_id']
|
||||||
|
if repo_id in vae_repos:
|
||||||
|
print(f" * Downloading diffusers VAE {model_name}...")
|
||||||
|
# TODO: can we autodetect when a repo has no fp16 revision?
|
||||||
|
AutoencoderKL.from_pretrained(repo_id)
|
||||||
|
elif repo_id not in diffusers_repos:
|
||||||
|
print(f" * Downloading diffusers {model_name}...")
|
||||||
|
StableDiffusionPipeline.from_pretrained(repo_id, **precision_args)
|
||||||
|
else:
|
||||||
|
warnings.warn(f" ⚠ FIXME: add diffusers repo for {repo_id}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
def download_diffusers_in_config(config_path: Path, full_precision: bool):
|
||||||
|
# This is a minimal implementation until https://github.com/invoke-ai/InvokeAI/pull/1490 lands,
|
||||||
|
# which moves a bunch of stuff.
|
||||||
|
# We can be more complete after we know it won't be all merge conflicts.
|
||||||
|
if not is_huggingface_authenticated():
|
||||||
|
print("*⚠ No Hugging Face access token; some downloads may be blocked.")
|
||||||
|
|
||||||
|
precision = 'full' if full_precision else 'float16'
|
||||||
|
cache = ModelCache(OmegaConf.load(config_path), precision=precision,
|
||||||
|
device_type='cpu', max_loaded_models=1)
|
||||||
|
for model_name in cache.list_models():
|
||||||
|
# TODO: download model without loading it.
|
||||||
|
# https://github.com/huggingface/diffusers/issues/1301
|
||||||
|
model_config = cache.config[model_name]
|
||||||
|
if model_config.get('format') == 'diffusers':
|
||||||
|
print(f" * Downloading diffusers {model_name}...")
|
||||||
|
cache.get_model(model_name)
|
||||||
|
cache.offload_model(model_name)
|
||||||
|
|
||||||
|
|
||||||
#---------------------------------------------
|
#---------------------------------------------
|
||||||
def update_config_file(successfully_downloaded:dict,opt:dict):
|
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||||
config_file = opt.config_file or Default_config_file
|
config_file = opt.config_file or Default_config_file
|
||||||
@ -404,7 +458,7 @@ def download_bert():
|
|||||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
from transformers import BertTokenizerFast
|
||||||
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
||||||
print('...success',file=sys.stderr)
|
print('...success',file=sys.stderr)
|
||||||
|
|
||||||
@ -727,6 +781,12 @@ def main():
|
|||||||
if opt.interactive:
|
if opt.interactive:
|
||||||
print('** DOWNLOADING DIFFUSION WEIGHTS **')
|
print('** DOWNLOADING DIFFUSION WEIGHTS **')
|
||||||
download_weights(opt)
|
download_weights(opt)
|
||||||
|
else:
|
||||||
|
config_path = Path(opt.config_file or Default_config_file)
|
||||||
|
if config_path.exists():
|
||||||
|
download_diffusers_in_config(config_path, full_precision=opt.full_precision)
|
||||||
|
else:
|
||||||
|
print("*⚠ No config file found; downloading no weights.")
|
||||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||||
download_bert()
|
download_bert()
|
||||||
download_clip()
|
download_clip()
|
||||||
|
Loading…
Reference in New Issue
Block a user