mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
model paths fixed, codeformer needs attention
This commit is contained in:
parent
4c035ad4ae
commit
274b276133
@ -27,6 +27,7 @@ from pytorch_lightning import seed_everything, logging
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.models.diffusion.ksampler import KSampler
|
||||
@ -220,8 +221,14 @@ class Generate:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
|
||||
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
|
||||
)
|
||||
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
|
||||
local_files_only=True,
|
||||
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
|
||||
)
|
||||
self.safety_checker.to(self.device)
|
||||
except Exception:
|
||||
print('** An error was encountered while installing the safety checker:')
|
||||
|
@ -389,6 +389,11 @@ class Args(object):
|
||||
|
||||
deprecated_group.add_argument('--laion400m')
|
||||
deprecated_group.add_argument('--weights') # deprecated
|
||||
model_group.add_argument(
|
||||
'--root_dir',
|
||||
default='.',
|
||||
help='Path to directory containing "models", "outputs" and "configs"'
|
||||
)
|
||||
model_group.add_argument(
|
||||
'--config',
|
||||
'-c',
|
||||
|
15
ldm/invoke/globals.py
Normal file
15
ldm/invoke/globals.py
Normal file
@ -0,0 +1,15 @@
|
||||
'''
|
||||
ldm.invoke.globals defines a small number of global variables that would
|
||||
otherwise have to be passed through long and complex call chains.
|
||||
|
||||
It defines a Namespace object named "Globals" that contains
|
||||
the attributes:
|
||||
|
||||
- root - the root directory under which "models" and "outputs" can be found
|
||||
'''
|
||||
|
||||
from argparse import Namespace
|
||||
|
||||
Globals = Namespace()
|
||||
Globals.root = '.'
|
||||
|
@ -19,6 +19,7 @@ from sys import getrefcount
|
||||
from omegaconf import OmegaConf
|
||||
from omegaconf.errors import ConfigAttributeError
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
DEFAULT_MAX_MODELS=2
|
||||
|
||||
@ -200,6 +201,9 @@ class ModelCache(object):
|
||||
width = mconfig.width
|
||||
height = mconfig.height
|
||||
|
||||
if not os.path.isabs(weights):
|
||||
weights = os.path.normpath(os.path.join(Globals.root,weights))
|
||||
|
||||
print(f'>> Loading {model_name} from {weights}')
|
||||
|
||||
# for usage statistics
|
||||
@ -210,6 +214,8 @@ class ModelCache(object):
|
||||
tic = time.time()
|
||||
|
||||
# this does the work
|
||||
if not os.path.isabs(config):
|
||||
config = os.path.join(Globals.root,config)
|
||||
c = OmegaConf.load(config)
|
||||
with open(weights,'rb') as f:
|
||||
weight_bytes = f.read()
|
||||
@ -228,6 +234,8 @@ class ModelCache(object):
|
||||
|
||||
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
|
||||
if vae:
|
||||
if not os.path.isabs(vae):
|
||||
vae = os.path.normpath(os.path.join(Globals.root,vae))
|
||||
if os.path.exists(vae):
|
||||
print(f' | Loading VAE weights from: {vae}')
|
||||
vae_ckpt = torch.load(vae, map_location="cpu")
|
||||
|
@ -3,13 +3,18 @@ import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
import sys
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
class CodeFormerRestoration():
|
||||
def __init__(self,
|
||||
codeformer_dir='ldm/invoke/restoration/codeformer',
|
||||
codeformer_model_path='weights/codeformer.pth') -> None:
|
||||
codeformer_dir='models/codeformer',
|
||||
codeformer_model_path='codeformer.pth') -> None:
|
||||
|
||||
if not os.path.isabs(codeformer_dir):
|
||||
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
|
||||
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
@ -33,9 +38,20 @@ class CodeFormerRestoration():
|
||||
|
||||
cf_class = CodeFormer
|
||||
|
||||
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True)
|
||||
cf = cf_class(
|
||||
dim_embd=512,
|
||||
codebook_size=1024,
|
||||
n_head=8,
|
||||
n_layers=9,
|
||||
connect_list=['32', '64', '128', '256']
|
||||
).to(device)
|
||||
|
||||
# note that this file should already be downloaded and cached at
|
||||
# this point
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url,
|
||||
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
|
||||
progress=True
|
||||
)
|
||||
checkpoint = torch.load(checkpoint_path)['params_ema']
|
||||
cf.load_state_dict(checkpoint)
|
||||
cf.eval()
|
||||
|
@ -3,6 +3,7 @@ import warnings
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@ -10,18 +11,21 @@ from PIL import Image
|
||||
class GFPGAN():
|
||||
def __init__(
|
||||
self,
|
||||
gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth') -> None:
|
||||
|
||||
self.model_path = os.path.join(gfpgan_model_path)
|
||||
gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
|
||||
) -> None:
|
||||
|
||||
if not os.path.isabs(gfpgan_model_path):
|
||||
gfpgan_model_path=os.path.join(Globals.root,gfpgan_model_path)
|
||||
self.model_path = gfpgan_model_path
|
||||
self.gfpgan_model_exists = os.path.isfile(self.model_path)
|
||||
|
||||
|
||||
if not self.gfpgan_model_exists:
|
||||
print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
|
||||
return None
|
||||
|
||||
|
||||
def model_exists(self):
|
||||
return os.path.isfile(self.model_path)
|
||||
|
||||
|
||||
def process(self, image, strength: float, seed: str = None):
|
||||
if seed is not None:
|
||||
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
||||
|
@ -1,5 +1,5 @@
|
||||
import math
|
||||
|
||||
import os.path
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -8,6 +8,7 @@ from einops import rearrange, repeat
|
||||
from transformers import CLIPTokenizer, CLIPTextModel
|
||||
import kornia
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
from ldm.invoke.globals import Globals
|
||||
|
||||
from ldm.modules.x_transformer import (
|
||||
Encoder,
|
||||
@ -98,21 +99,19 @@ class BERTTokenizer(AbstractEncoder):
|
||||
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
|
||||
def __init__(
|
||||
self, device=choose_torch_device(), vq_interface=True, max_length=77
|
||||
self, device=choose_torch_device(), vq_interface=True, max_length=77
|
||||
):
|
||||
super().__init__()
|
||||
from transformers import (
|
||||
BertTokenizerFast,
|
||||
) # TODO: add to reuquirements
|
||||
)
|
||||
|
||||
# Modified to allow to run on non-internet connected compute nodes.
|
||||
# Model needs to be loaded into cache from an internet-connected machine
|
||||
# by running:
|
||||
# from transformers import BertTokenizerFast
|
||||
# BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
cache = os.path.join(Globals.root,'models/bert-base-uncased')
|
||||
try:
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained(
|
||||
'bert-base-uncased', local_files_only=True
|
||||
'bert-base-uncased',
|
||||
cache_dir=cache,
|
||||
local_files_only=True
|
||||
)
|
||||
except OSError:
|
||||
raise SystemExit(
|
||||
@ -150,14 +149,14 @@ class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size=30522,
|
||||
max_seq_len=77,
|
||||
device=choose_torch_device(),
|
||||
use_tokenizer=True,
|
||||
embedding_dropout=0.0,
|
||||
self,
|
||||
n_embed,
|
||||
n_layer,
|
||||
vocab_size=30522,
|
||||
max_seq_len=77,
|
||||
device=choose_torch_device(),
|
||||
use_tokenizer=True,
|
||||
embedding_dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
@ -245,10 +244,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
):
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(
|
||||
version, local_files_only=True
|
||||
version,
|
||||
cache_dir=os.path.join(Globals.root,'models',version),
|
||||
local_files_only=True
|
||||
)
|
||||
self.transformer = CLIPTextModel.from_pretrained(
|
||||
version, local_files_only=True
|
||||
version,
|
||||
cache_dir=os.path.join(Globals.root,'models',version),
|
||||
local_files_only=True
|
||||
)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
|
@ -1,2 +0,0 @@
|
||||
See docs/features/INSTALLING_MODELS.md for how to populate this
|
||||
directory with one or more Stable Diffusion model weight files.
|
@ -11,9 +11,9 @@ import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
|
||||
sys.path.append('.') # corrects a weird problem on Macs
|
||||
from ldm.invoke.globals import Globals
|
||||
from ldm.invoke.prompt_parser import PromptParser
|
||||
from ldm.invoke.readline import get_completer
|
||||
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
|
||||
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
|
||||
@ -47,6 +47,9 @@ def main():
|
||||
print('--max_loaded_models must be >= 1; using 1')
|
||||
args.max_loaded_models = 1
|
||||
|
||||
# alert - setting a global here
|
||||
Globals.root=os.path.expanduser(args.root_dir)
|
||||
|
||||
from ldm.generate import Generate
|
||||
|
||||
# these two lines prevent a horrible warning message from appearing
|
||||
@ -57,10 +60,16 @@ def main():
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
gfpgan,codeformer,esrgan = load_face_restoration(opt)
|
||||
|
||||
# make sure the output directory exists
|
||||
# normalize the outdir relative to root and make sure it exists
|
||||
if not os.path.abspath(opt.outdir):
|
||||
opt.outdir=os.path.normpath(os.path.join(Globals.root,opt.outdir))
|
||||
if not os.path.exists(opt.outdir):
|
||||
os.makedirs(opt.outdir)
|
||||
|
||||
# normalize the config directory relative to root
|
||||
if not os.path.abspath(opt.conf):
|
||||
opt.conf=os.path.normpath(os.path.join(Globals.root,opt.conf))
|
||||
|
||||
# load the infile as a list of lines
|
||||
if opt.infile:
|
||||
try:
|
||||
@ -77,7 +86,7 @@ def main():
|
||||
# creating a Generate object:
|
||||
try:
|
||||
gen = Generate(
|
||||
conf = opt.conf,
|
||||
conf = os.path.join(Globals.root,opt.conf),
|
||||
model = opt.model,
|
||||
sampler_name = opt.sampler_name,
|
||||
embedding_path = opt.embedding_path,
|
||||
@ -128,6 +137,8 @@ def main_loop(gen, opt):
|
||||
doneAfterInFile = infile is not None
|
||||
path_filter = re.compile(r'[<>:"/\\|?*]')
|
||||
last_results = list()
|
||||
if not os.path.isabs(opt.conf):
|
||||
opt.conf = os.path.join(Globals.root,opt.conf)
|
||||
model_config = OmegaConf.load(opt.conf)
|
||||
|
||||
# The readline completer reads history from the .dream_history file located in the
|
||||
|
@ -1,5 +1,6 @@
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
|
||||
from ldm.modules.embedding_manager import EmbeddingManager
|
||||
from ldm.modules.globals import Globals
|
||||
|
||||
import argparse, os
|
||||
from functools import partial
|
||||
@ -51,6 +52,13 @@ if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--root_dir",
|
||||
type=str,
|
||||
default='.',
|
||||
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--manager_ckpts",
|
||||
type=str,
|
||||
@ -73,6 +81,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
Globals.root=args.root_dir
|
||||
|
||||
if args.use_bert:
|
||||
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
|
||||
|
103
scripts/preload_models.py
Normal file → Executable file
103
scripts/preload_models.py
Normal file → Executable file
@ -32,7 +32,9 @@ warnings.filterwarnings('ignore')
|
||||
#warnings.filterwarnings('ignore',category=UserWarning)
|
||||
|
||||
#--------------------------globals--
|
||||
Model_dir = './models/ldm/stable-diffusion-v1/'
|
||||
Root_dir = '.'
|
||||
Model_dir = 'models'
|
||||
Weights_dir = 'ldm/stable-diffusion-v1/'
|
||||
Default_config_file = './configs/models.yaml'
|
||||
SD_Configs = './configs/stable-diffusion'
|
||||
Datasets = {
|
||||
@ -253,14 +255,15 @@ This involves a few easy steps.
|
||||
# look for legacy model.ckpt in models directory and offer to
|
||||
# normalize its name
|
||||
def migrate_models_ckpt():
|
||||
if not os.path.exists(os.path.join(Model_dir,'model.ckpt')):
|
||||
model_path = os.path.join(Root_dir,Model_dir,Weights_dir)
|
||||
if not os.path.exists(os.path.join(model_path,'model.ckpt')):
|
||||
return
|
||||
new_name = Datasets['stable-diffusion-1.4']['file']
|
||||
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
|
||||
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
|
||||
if rename:
|
||||
print(f'model.ckpt => {new_name}')
|
||||
os.rename(os.path.join(Model_dir,'model.ckpt'),os.path.join(Model_dir,new_name))
|
||||
os.rename(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
|
||||
|
||||
#---------------------------------------------
|
||||
def download_weight_datasets(models:dict, access_token:str):
|
||||
@ -271,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
filename = Datasets[mod]['file']
|
||||
success = download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=os.path.join(Root_dir,Model_dir,Weights_dir),
|
||||
model_name=filename,
|
||||
access_token=access_token
|
||||
)
|
||||
@ -290,12 +294,12 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
return successful
|
||||
|
||||
#---------------------------------------------
|
||||
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
|
||||
model_dest = os.path.join(Model_dir, model_name)
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
|
||||
model_dest = os.path.join(model_dir, model_name)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
url = hf_hub_url(repo_id, model_name)
|
||||
|
||||
header = {"Authorization": f'Bearer {access_token}'}
|
||||
header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
|
||||
open_mode = 'wb'
|
||||
exist_size = 0
|
||||
|
||||
@ -399,30 +403,27 @@ def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str
|
||||
#---------------------------------------------
|
||||
# this will preload the Bert tokenizer fles
|
||||
def download_bert():
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='')
|
||||
sys.stdout.flush()
|
||||
print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
from transformers import BertTokenizerFast, AutoFeatureExtractor
|
||||
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
|
||||
print('...success')
|
||||
download_from_hf(BertTokenizerFast,'bert-base-uncased')
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#---------------------------------------------
|
||||
# this will download requirements for Kornia
|
||||
def download_kornia():
|
||||
print('Installing Kornia requirements (ignore deprecation errors)...', end='')
|
||||
sys.stdout.flush()
|
||||
import kornia
|
||||
print('...success')
|
||||
def download_from_hf(model_class:object, model_name:str):
|
||||
return model_class.from_pretrained(model_name,
|
||||
cache_dir=os.path.join(Root_dir,Model_dir,model_name),
|
||||
resume_download=True
|
||||
)
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
print('Loading CLIP model (ignore deprecation errors)...',end='')
|
||||
sys.stdout.flush()
|
||||
print('Loading CLIP model (ignore deprecation errors)...',end='',file=sys.stderr)
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
transformer = CLIPTextModel.from_pretrained(version)
|
||||
print('...success')
|
||||
download_from_hf(CLIPTokenizer,version)
|
||||
download_from_hf(CLIPTextModel,version)
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#---------------------------------------------
|
||||
def download_gfpgan():
|
||||
@ -464,7 +465,7 @@ def download_gfpgan():
|
||||
if not os.path.exists(model_dest):
|
||||
print(f'Downloading gfpgan model file {model_url}...',end='')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url,model_dest)
|
||||
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
|
||||
print('...success')
|
||||
except Exception:
|
||||
print('Error loading GFPGAN:')
|
||||
@ -472,18 +473,18 @@ def download_gfpgan():
|
||||
|
||||
#---------------------------------------------
|
||||
def download_codeformer():
|
||||
print('Installing CodeFormer model file...',end='')
|
||||
print('Installing CodeFormer model file...',end='',file=sys.stderr)
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url,model_dest)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = os.path.join(Root_dir,'models/codeformer/codeformer.pth')
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clipseg():
|
||||
@ -497,7 +498,7 @@ def download_clipseg():
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
|
||||
request.urlretrieve(model_url,weights_zip)
|
||||
request.urlretrieve(model_url,weights_zip,ProgressBar(os.path.basename(model_dest)))
|
||||
with zipfile.ZipFile(weights_zip,'r') as zip:
|
||||
zip.extractall('models/clipseg')
|
||||
os.remove(weights_zip)
|
||||
@ -519,7 +520,7 @@ def download_clipseg():
|
||||
|
||||
#-------------------------------------
|
||||
def download_safety_checker():
|
||||
print('Installing safety model for NSFW content detection...',end='')
|
||||
print('Installing safety model for NSFW content detection...',end='',file=sys.stderr)
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
@ -528,9 +529,25 @@ def download_safety_checker():
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
||||
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
||||
print('...success')
|
||||
download_from_hf(AutoFeatureExtractor,safety_model_id)
|
||||
download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#-------------------------------------
|
||||
class ProgressBar():
|
||||
def __init__(self,model_name='file'):
|
||||
self.pbar = None
|
||||
self.name = model_name
|
||||
|
||||
def __call__(self, block_num, block_size, total_size):
|
||||
if not self.pbar:
|
||||
self.pbar=tqdm(desc=self.name,
|
||||
initial=0,
|
||||
unit='iB',
|
||||
unit_scale=True,
|
||||
total=total_size)
|
||||
downloaded = block_num * block_size
|
||||
self.pbar.update(downloaded)
|
||||
|
||||
#-------------------------------------
|
||||
if __name__ == '__main__':
|
||||
@ -540,13 +557,26 @@ if __name__ == '__main__':
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='run in interactive mode (default)')
|
||||
parser.add_argument('--yes','-y',
|
||||
dest='yes_to_all',
|
||||
action=argparse.BooleanOptionalAction,
|
||||
default=True,
|
||||
help='run in interactive mode (default)')
|
||||
parser.add_argument('--config_file',
|
||||
'-c',
|
||||
dest='config_file',
|
||||
type=str,
|
||||
default='./configs/models.yaml',
|
||||
help='path to configuration file to create')
|
||||
parser.add_argument('--root',
|
||||
dest='root',
|
||||
type=str,
|
||||
default='.',
|
||||
help='path to root of install directory')
|
||||
opt = parser.parse_args()
|
||||
|
||||
# setting a global here
|
||||
Root_dir = os.path.expanduser(opt.root)
|
||||
|
||||
try:
|
||||
if opt.interactive:
|
||||
@ -565,7 +595,6 @@ if __name__ == '__main__':
|
||||
update_config_file(successfully_downloaded,opt)
|
||||
print('\n** DOWNLOADING SUPPORT MODELS **')
|
||||
download_bert()
|
||||
download_kornia()
|
||||
download_clip()
|
||||
download_gfpgan()
|
||||
download_codeformer()
|
||||
|
Loading…
Reference in New Issue
Block a user