model paths fixed, codeformer needs attention

This commit is contained in:
Lincoln Stein 2022-11-15 01:53:10 +00:00
parent 4c035ad4ae
commit 274b276133
11 changed files with 180 additions and 75 deletions

View File

@ -27,6 +27,7 @@ from pytorch_lightning import seed_everything, logging
from ldm.invoke.prompt_parser import PromptParser from ldm.invoke.prompt_parser import PromptParser
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.invoke.globals import Globals
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
@ -220,8 +221,14 @@ class Generate:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True) self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True) local_files_only=True,
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
)
self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
local_files_only=True,
cache_dir=os.path.join(Globals.root,'models',safety_model_id)
)
self.safety_checker.to(self.device) self.safety_checker.to(self.device)
except Exception: except Exception:
print('** An error was encountered while installing the safety checker:') print('** An error was encountered while installing the safety checker:')

View File

@ -389,6 +389,11 @@ class Args(object):
deprecated_group.add_argument('--laion400m') deprecated_group.add_argument('--laion400m')
deprecated_group.add_argument('--weights') # deprecated deprecated_group.add_argument('--weights') # deprecated
model_group.add_argument(
'--root_dir',
default='.',
help='Path to directory containing "models", "outputs" and "configs"'
)
model_group.add_argument( model_group.add_argument(
'--config', '--config',
'-c', '-c',

15
ldm/invoke/globals.py Normal file
View File

@ -0,0 +1,15 @@
'''
ldm.invoke.globals defines a small number of global variables that would
otherwise have to be passed through long and complex call chains.
It defines a Namespace object named "Globals" that contains
the attributes:
- root - the root directory under which "models" and "outputs" can be found
'''
from argparse import Namespace
Globals = Namespace()
Globals.root = '.'

View File

@ -19,6 +19,7 @@ from sys import getrefcount
from omegaconf import OmegaConf from omegaconf import OmegaConf
from omegaconf.errors import ConfigAttributeError from omegaconf.errors import ConfigAttributeError
from ldm.util import instantiate_from_config from ldm.util import instantiate_from_config
from ldm.invoke.globals import Globals
DEFAULT_MAX_MODELS=2 DEFAULT_MAX_MODELS=2
@ -200,6 +201,9 @@ class ModelCache(object):
width = mconfig.width width = mconfig.width
height = mconfig.height height = mconfig.height
if not os.path.isabs(weights):
weights = os.path.normpath(os.path.join(Globals.root,weights))
print(f'>> Loading {model_name} from {weights}') print(f'>> Loading {model_name} from {weights}')
# for usage statistics # for usage statistics
@ -210,6 +214,8 @@ class ModelCache(object):
tic = time.time() tic = time.time()
# this does the work # this does the work
if not os.path.isabs(config):
config = os.path.join(Globals.root,config)
c = OmegaConf.load(config) c = OmegaConf.load(config)
with open(weights,'rb') as f: with open(weights,'rb') as f:
weight_bytes = f.read() weight_bytes = f.read()
@ -228,6 +234,8 @@ class ModelCache(object):
# look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
if vae: if vae:
if not os.path.isabs(vae):
vae = os.path.normpath(os.path.join(Globals.root,vae))
if os.path.exists(vae): if os.path.exists(vae):
print(f' | Loading VAE weights from: {vae}') print(f' | Loading VAE weights from: {vae}')
vae_ckpt = torch.load(vae, map_location="cpu") vae_ckpt = torch.load(vae, map_location="cpu")

View File

@ -3,13 +3,18 @@ import torch
import numpy as np import numpy as np
import warnings import warnings
import sys import sys
from ldm.invoke.globals import Globals
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
class CodeFormerRestoration(): class CodeFormerRestoration():
def __init__(self, def __init__(self,
codeformer_dir='ldm/invoke/restoration/codeformer', codeformer_dir='models/codeformer',
codeformer_model_path='weights/codeformer.pth') -> None: codeformer_model_path='codeformer.pth') -> None:
if not os.path.isabs(codeformer_dir):
codeformer_dir = os.path.join(Globals.root, codeformer_dir)
self.model_path = os.path.join(codeformer_dir, codeformer_model_path) self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path) self.codeformer_model_exists = os.path.isfile(self.model_path)
@ -33,9 +38,20 @@ class CodeFormerRestoration():
cf_class = CodeFormer cf_class = CodeFormer
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device) cf = cf_class(
dim_embd=512,
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True) codebook_size=1024,
n_head=8,
n_layers=9,
connect_list=['32', '64', '128', '256']
).to(device)
# note that this file should already be downloaded and cached at
# this point
checkpoint_path = load_file_from_url(url=pretrained_model_url,
model_dir=os.path.abspath(os.path.dirname(self.model_path)),
progress=True
)
checkpoint = torch.load(checkpoint_path)['params_ema'] checkpoint = torch.load(checkpoint_path)['params_ema']
cf.load_state_dict(checkpoint) cf.load_state_dict(checkpoint)
cf.eval() cf.eval()

View File

@ -3,6 +3,7 @@ import warnings
import os import os
import sys import sys
import numpy as np import numpy as np
from ldm.invoke.globals import Globals
from PIL import Image from PIL import Image
@ -10,18 +11,21 @@ from PIL import Image
class GFPGAN(): class GFPGAN():
def __init__( def __init__(
self, self,
gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth') -> None: gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
) -> None:
self.model_path = os.path.join(gfpgan_model_path)
if not os.path.isabs(gfpgan_model_path):
gfpgan_model_path=os.path.join(Globals.root,gfpgan_model_path)
self.model_path = gfpgan_model_path
self.gfpgan_model_exists = os.path.isfile(self.model_path) self.gfpgan_model_exists = os.path.isfile(self.model_path)
if not self.gfpgan_model_exists: if not self.gfpgan_model_exists:
print('## NOT FOUND: GFPGAN model not found at ' + self.model_path) print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
return None return None
def model_exists(self): def model_exists(self):
return os.path.isfile(self.model_path) return os.path.isfile(self.model_path)
def process(self, image, strength: float, seed: str = None): def process(self, image, strength: float, seed: str = None):
if seed is not None: if seed is not None:
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}') print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')

View File

@ -1,5 +1,5 @@
import math import math
import os.path
import torch import torch
import torch.nn as nn import torch.nn as nn
from functools import partial from functools import partial
@ -8,6 +8,7 @@ from einops import rearrange, repeat
from transformers import CLIPTokenizer, CLIPTextModel from transformers import CLIPTokenizer, CLIPTextModel
import kornia import kornia
from ldm.invoke.devices import choose_torch_device from ldm.invoke.devices import choose_torch_device
from ldm.invoke.globals import Globals
from ldm.modules.x_transformer import ( from ldm.modules.x_transformer import (
Encoder, Encoder,
@ -98,21 +99,19 @@ class BERTTokenizer(AbstractEncoder):
"""Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__( def __init__(
self, device=choose_torch_device(), vq_interface=True, max_length=77 self, device=choose_torch_device(), vq_interface=True, max_length=77
): ):
super().__init__() super().__init__()
from transformers import ( from transformers import (
BertTokenizerFast, BertTokenizerFast,
) # TODO: add to reuquirements )
# Modified to allow to run on non-internet connected compute nodes. cache = os.path.join(Globals.root,'models/bert-base-uncased')
# Model needs to be loaded into cache from an internet-connected machine
# by running:
# from transformers import BertTokenizerFast
# BertTokenizerFast.from_pretrained("bert-base-uncased")
try: try:
self.tokenizer = BertTokenizerFast.from_pretrained( self.tokenizer = BertTokenizerFast.from_pretrained(
'bert-base-uncased', local_files_only=True 'bert-base-uncased',
cache_dir=cache,
local_files_only=True
) )
except OSError: except OSError:
raise SystemExit( raise SystemExit(
@ -150,14 +149,14 @@ class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers""" """Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__( def __init__(
self, self,
n_embed, n_embed,
n_layer, n_layer,
vocab_size=30522, vocab_size=30522,
max_seq_len=77, max_seq_len=77,
device=choose_torch_device(), device=choose_torch_device(),
use_tokenizer=True, use_tokenizer=True,
embedding_dropout=0.0, embedding_dropout=0.0,
): ):
super().__init__() super().__init__()
self.use_tknz_fn = use_tokenizer self.use_tknz_fn = use_tokenizer
@ -245,10 +244,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
): ):
super().__init__() super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained( self.tokenizer = CLIPTokenizer.from_pretrained(
version, local_files_only=True version,
cache_dir=os.path.join(Globals.root,'models',version),
local_files_only=True
) )
self.transformer = CLIPTextModel.from_pretrained( self.transformer = CLIPTextModel.from_pretrained(
version, local_files_only=True version,
cache_dir=os.path.join(Globals.root,'models',version),
local_files_only=True
) )
self.device = device self.device = device
self.max_length = max_length self.max_length = max_length

View File

@ -1,2 +0,0 @@
See docs/features/INSTALLING_MODELS.md for how to populate this
directory with one or more Stable Diffusion model weight files.

View File

@ -11,9 +11,9 @@ import time
import traceback import traceback
import yaml import yaml
from ldm.invoke.prompt_parser import PromptParser
sys.path.append('.') # corrects a weird problem on Macs sys.path.append('.') # corrects a weird problem on Macs
from ldm.invoke.globals import Globals
from ldm.invoke.prompt_parser import PromptParser
from ldm.invoke.readline import get_completer from ldm.invoke.readline import get_completer
from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
@ -47,6 +47,9 @@ def main():
print('--max_loaded_models must be >= 1; using 1') print('--max_loaded_models must be >= 1; using 1')
args.max_loaded_models = 1 args.max_loaded_models = 1
# alert - setting a global here
Globals.root=os.path.expanduser(args.root_dir)
from ldm.generate import Generate from ldm.generate import Generate
# these two lines prevent a horrible warning message from appearing # these two lines prevent a horrible warning message from appearing
@ -57,10 +60,16 @@ def main():
# Loading Face Restoration and ESRGAN Modules # Loading Face Restoration and ESRGAN Modules
gfpgan,codeformer,esrgan = load_face_restoration(opt) gfpgan,codeformer,esrgan = load_face_restoration(opt)
# make sure the output directory exists # normalize the outdir relative to root and make sure it exists
if not os.path.abspath(opt.outdir):
opt.outdir=os.path.normpath(os.path.join(Globals.root,opt.outdir))
if not os.path.exists(opt.outdir): if not os.path.exists(opt.outdir):
os.makedirs(opt.outdir) os.makedirs(opt.outdir)
# normalize the config directory relative to root
if not os.path.abspath(opt.conf):
opt.conf=os.path.normpath(os.path.join(Globals.root,opt.conf))
# load the infile as a list of lines # load the infile as a list of lines
if opt.infile: if opt.infile:
try: try:
@ -77,7 +86,7 @@ def main():
# creating a Generate object: # creating a Generate object:
try: try:
gen = Generate( gen = Generate(
conf = opt.conf, conf = os.path.join(Globals.root,opt.conf),
model = opt.model, model = opt.model,
sampler_name = opt.sampler_name, sampler_name = opt.sampler_name,
embedding_path = opt.embedding_path, embedding_path = opt.embedding_path,
@ -128,6 +137,8 @@ def main_loop(gen, opt):
doneAfterInFile = infile is not None doneAfterInFile = infile is not None
path_filter = re.compile(r'[<>:"/\\|?*]') path_filter = re.compile(r'[<>:"/\\|?*]')
last_results = list() last_results = list()
if not os.path.isabs(opt.conf):
opt.conf = os.path.join(Globals.root,opt.conf)
model_config = OmegaConf.load(opt.conf) model_config = OmegaConf.load(opt.conf)
# The readline completer reads history from the .dream_history file located in the # The readline completer reads history from the .dream_history file located in the

View File

@ -1,5 +1,6 @@
from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
from ldm.modules.embedding_manager import EmbeddingManager from ldm.modules.embedding_manager import EmbeddingManager
from ldm.modules.globals import Globals
import argparse, os import argparse, os
from functools import partial from functools import partial
@ -51,6 +52,13 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(
"--root_dir",
type=str,
default='.',
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
)
parser.add_argument( parser.add_argument(
"--manager_ckpts", "--manager_ckpts",
type=str, type=str,
@ -73,6 +81,7 @@ if __name__ == "__main__":
) )
args = parser.parse_args() args = parser.parse_args()
Globals.root=args.root_dir
if args.use_bert: if args.use_bert:
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda() embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()

103
scripts/preload_models.py Normal file → Executable file
View File

@ -32,7 +32,9 @@ warnings.filterwarnings('ignore')
#warnings.filterwarnings('ignore',category=UserWarning) #warnings.filterwarnings('ignore',category=UserWarning)
#--------------------------globals-- #--------------------------globals--
Model_dir = './models/ldm/stable-diffusion-v1/' Root_dir = '.'
Model_dir = 'models'
Weights_dir = 'ldm/stable-diffusion-v1/'
Default_config_file = './configs/models.yaml' Default_config_file = './configs/models.yaml'
SD_Configs = './configs/stable-diffusion' SD_Configs = './configs/stable-diffusion'
Datasets = { Datasets = {
@ -253,14 +255,15 @@ This involves a few easy steps.
# look for legacy model.ckpt in models directory and offer to # look for legacy model.ckpt in models directory and offer to
# normalize its name # normalize its name
def migrate_models_ckpt(): def migrate_models_ckpt():
if not os.path.exists(os.path.join(Model_dir,'model.ckpt')): model_path = os.path.join(Root_dir,Model_dir,Weights_dir)
if not os.path.exists(os.path.join(model_path,'model.ckpt')):
return return
new_name = Datasets['stable-diffusion-1.4']['file'] new_name = Datasets['stable-diffusion-1.4']['file']
print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.') print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?') rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
if rename: if rename:
print(f'model.ckpt => {new_name}') print(f'model.ckpt => {new_name}')
os.rename(os.path.join(Model_dir,'model.ckpt'),os.path.join(Model_dir,new_name)) os.rename(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
#--------------------------------------------- #---------------------------------------------
def download_weight_datasets(models:dict, access_token:str): def download_weight_datasets(models:dict, access_token:str):
@ -271,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
filename = Datasets[mod]['file'] filename = Datasets[mod]['file']
success = download_with_resume( success = download_with_resume(
repo_id=repo_id, repo_id=repo_id,
model_dir=os.path.join(Root_dir,Model_dir,Weights_dir),
model_name=filename, model_name=filename,
access_token=access_token access_token=access_token
) )
@ -290,12 +294,12 @@ def download_weight_datasets(models:dict, access_token:str):
return successful return successful
#--------------------------------------------- #---------------------------------------------
def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool: def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
model_dest = os.path.join(Model_dir, model_name) model_dest = os.path.join(model_dir, model_name)
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(model_dir, exist_ok=True)
url = hf_hub_url(repo_id, model_name) url = hf_hub_url(repo_id, model_name)
header = {"Authorization": f'Bearer {access_token}'} header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
open_mode = 'wb' open_mode = 'wb'
exist_size = 0 exist_size = 0
@ -399,30 +403,27 @@ def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str
#--------------------------------------------- #---------------------------------------------
# this will preload the Bert tokenizer fles # this will preload the Bert tokenizer fles
def download_bert(): def download_bert():
print('Installing bert tokenizer (ignore deprecation errors)...', end='') print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
sys.stdout.flush()
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
from transformers import BertTokenizerFast, AutoFeatureExtractor from transformers import BertTokenizerFast, AutoFeatureExtractor
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') download_from_hf(BertTokenizerFast,'bert-base-uncased')
print('...success') print('...success',file=sys.stderr)
#--------------------------------------------- #---------------------------------------------
# this will download requirements for Kornia def download_from_hf(model_class:object, model_name:str):
def download_kornia(): return model_class.from_pretrained(model_name,
print('Installing Kornia requirements (ignore deprecation errors)...', end='') cache_dir=os.path.join(Root_dir,Model_dir,model_name),
sys.stdout.flush() resume_download=True
import kornia )
print('...success')
#--------------------------------------------- #---------------------------------------------
def download_clip(): def download_clip():
print('Loading CLIP model (ignore deprecation errors)...',end='') print('Loading CLIP model (ignore deprecation errors)...',end='',file=sys.stderr)
sys.stdout.flush()
version = 'openai/clip-vit-large-patch14' version = 'openai/clip-vit-large-patch14'
tokenizer = CLIPTokenizer.from_pretrained(version) download_from_hf(CLIPTokenizer,version)
transformer = CLIPTextModel.from_pretrained(version) download_from_hf(CLIPTextModel,version)
print('...success') print('...success',file=sys.stderr)
#--------------------------------------------- #---------------------------------------------
def download_gfpgan(): def download_gfpgan():
@ -464,7 +465,7 @@ def download_gfpgan():
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
print(f'Downloading gfpgan model file {model_url}...',end='') print(f'Downloading gfpgan model file {model_url}...',end='')
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(model_url,model_dest) request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
print('...success') print('...success')
except Exception: except Exception:
print('Error loading GFPGAN:') print('Error loading GFPGAN:')
@ -472,18 +473,18 @@ def download_gfpgan():
#--------------------------------------------- #---------------------------------------------
def download_codeformer(): def download_codeformer():
print('Installing CodeFormer model file...',end='') print('Installing CodeFormer model file...',end='',file=sys.stderr)
try: try:
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth' model_dest = os.path.join(Root_dir,'models/codeformer/codeformer.pth')
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
print('Downloading codeformer model file...') print('Downloading codeformer model file...')
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(model_url,model_dest) request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
except Exception: except Exception:
print('Error loading CodeFormer:') print('Error loading CodeFormer:')
print(traceback.format_exc()) print(traceback.format_exc())
print('...success') print('...success',file=sys.stderr)
#--------------------------------------------- #---------------------------------------------
def download_clipseg(): def download_clipseg():
@ -497,7 +498,7 @@ def download_clipseg():
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'): if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
request.urlretrieve(model_url,weights_zip) request.urlretrieve(model_url,weights_zip,ProgressBar(os.path.basename(model_dest)))
with zipfile.ZipFile(weights_zip,'r') as zip: with zipfile.ZipFile(weights_zip,'r') as zip:
zip.extractall('models/clipseg') zip.extractall('models/clipseg')
os.remove(weights_zip) os.remove(weights_zip)
@ -519,7 +520,7 @@ def download_clipseg():
#------------------------------------- #-------------------------------------
def download_safety_checker(): def download_safety_checker():
print('Installing safety model for NSFW content detection...',end='') print('Installing safety model for NSFW content detection...',end='',file=sys.stderr)
try: try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
@ -528,9 +529,25 @@ def download_safety_checker():
print(traceback.format_exc()) print(traceback.format_exc())
return return
safety_model_id = "CompVis/stable-diffusion-safety-checker" safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id) download_from_hf(AutoFeatureExtractor,safety_model_id)
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id) download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
print('...success') print('...success',file=sys.stderr)
#-------------------------------------
class ProgressBar():
def __init__(self,model_name='file'):
self.pbar = None
self.name = model_name
def __call__(self, block_num, block_size, total_size):
if not self.pbar:
self.pbar=tqdm(desc=self.name,
initial=0,
unit='iB',
unit_scale=True,
total=total_size)
downloaded = block_num * block_size
self.pbar.update(downloaded)
#------------------------------------- #-------------------------------------
if __name__ == '__main__': if __name__ == '__main__':
@ -540,13 +557,26 @@ if __name__ == '__main__':
action=argparse.BooleanOptionalAction, action=argparse.BooleanOptionalAction,
default=True, default=True,
help='run in interactive mode (default)') help='run in interactive mode (default)')
parser.add_argument('--yes','-y',
dest='yes_to_all',
action=argparse.BooleanOptionalAction,
default=True,
help='run in interactive mode (default)')
parser.add_argument('--config_file', parser.add_argument('--config_file',
'-c', '-c',
dest='config_file', dest='config_file',
type=str, type=str,
default='./configs/models.yaml', default='./configs/models.yaml',
help='path to configuration file to create') help='path to configuration file to create')
parser.add_argument('--root',
dest='root',
type=str,
default='.',
help='path to root of install directory')
opt = parser.parse_args() opt = parser.parse_args()
# setting a global here
Root_dir = os.path.expanduser(opt.root)
try: try:
if opt.interactive: if opt.interactive:
@ -565,7 +595,6 @@ if __name__ == '__main__':
update_config_file(successfully_downloaded,opt) update_config_file(successfully_downloaded,opt)
print('\n** DOWNLOADING SUPPORT MODELS **') print('\n** DOWNLOADING SUPPORT MODELS **')
download_bert() download_bert()
download_kornia()
download_clip() download_clip()
download_gfpgan() download_gfpgan()
download_codeformer() download_codeformer()