mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

commit 274b276133 (parent 4c035ad4ae)

    model paths fixed, codeformer needs attention
@@ -27,6 +27,7 @@ from pytorch_lightning import seed_everything, logging
 from ldm.invoke.prompt_parser import PromptParser
 from ldm.util import instantiate_from_config
+from ldm.invoke.globals import Globals
 from ldm.models.diffusion.ddim import DDIMSampler
 from ldm.models.diffusion.plms import PLMSSampler
 from ldm.models.diffusion.ksampler import KSampler
@@ -220,8 +221,14 @@ class Generate:
                 from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
                 from transformers import AutoFeatureExtractor
                 safety_model_id = "CompVis/stable-diffusion-safety-checker"
-                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id, local_files_only=True)
-                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id, local_files_only=True)
+                self.safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id,
+                                                                                   local_files_only=True,
+                                                                                   cache_dir=os.path.join(Globals.root,'models',safety_model_id)
+                                                                                   )
+                self.safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id,
+                                                                                     local_files_only=True,
+                                                                                     cache_dir=os.path.join(Globals.root,'models',safety_model_id)
+                                                                                     )
                 self.safety_checker.to(self.device)
             except Exception:
                 print('** An error was encountered while installing the safety checker:')
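
With this change the two safety-model calls cache their weights under the install tree instead of the default HuggingFace cache. A quick standalone check of the directory the patched calls compute (Globals.root shown at its '.' default; snippet not part of the commit):

import os

safety_model_id = "CompVis/stable-diffusion-safety-checker"
root = '.'  # Globals.root default; overridden by --root_dir at startup
print(os.path.join(root, 'models', safety_model_id))
# ./models/CompVis/stable-diffusion-safety-checker
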
@@ -389,6 +389,11 @@ class Args(object):
         deprecated_group.add_argument('--laion400m')
         deprecated_group.add_argument('--weights') # deprecated
+        model_group.add_argument(
+            '--root_dir',
+            default='.',
+            help='Path to directory containing "models", "outputs" and "configs"'
+        )
         model_group.add_argument(
             '--config',
             '-c',
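
The new --root_dir flag is what later seeds Globals.root in main(). A minimal sketch of that wiring, using a bare ArgumentParser rather than the repo's Args class:

import argparse, os

parser = argparse.ArgumentParser()
parser.add_argument('--root_dir', default='.',
                    help='Path to directory containing "models", "outputs" and "configs"')
args = parser.parse_args(['--root_dir', '~/invokeai'])
print(os.path.expanduser(args.root_dir))  # main() expands '~' the same way
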
ldm/invoke/globals.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+'''
+ldm.invoke.globals defines a small number of global variables that would
+otherwise have to be passed through long and complex call chains.
+
+It defines a Namespace object named "Globals" that contains
+the attributes:
+
+  - root - the root directory under which "models" and "outputs" can be found
+'''
+
+from argparse import Namespace
+
+Globals = Namespace()
+Globals.root = '.'
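
Nearly every call site in this commit follows the same pattern against the new Globals namespace: leave absolute paths alone, anchor relative ones at Globals.root. A self-contained sketch of that idiom (the resolve() helper is illustrative, not a function the commit adds):

import os
from argparse import Namespace

Globals = Namespace()           # stands in for ldm.invoke.globals.Globals
Globals.root = '/opt/invokeai'  # normally set from --root_dir at startup

def resolve(path: str) -> str:
    # relative paths are anchored at the install root; absolute paths pass through
    if os.path.isabs(path):
        return path
    return os.path.normpath(os.path.join(Globals.root, path))

print(resolve('models/gfpgan/GFPGANv1.4.pth'))  # /opt/invokeai/models/gfpgan/GFPGANv1.4.pth
print(resolve('/tmp/custom.ckpt'))              # /tmp/custom.ckpt unchanged
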
@@ -19,6 +19,7 @@ from sys import getrefcount
 from omegaconf import OmegaConf
 from omegaconf.errors import ConfigAttributeError
 from ldm.util import instantiate_from_config
+from ldm.invoke.globals import Globals
 
 DEFAULT_MAX_MODELS=2
 
@@ -200,6 +201,9 @@ class ModelCache(object):
         width = mconfig.width
         height = mconfig.height
 
+        if not os.path.isabs(weights):
+            weights = os.path.normpath(os.path.join(Globals.root,weights))
+
         print(f'>> Loading {model_name} from {weights}')
 
         # for usage statistics
@@ -210,6 +214,8 @@ class ModelCache(object):
         tic = time.time()
 
         # this does the work
+        if not os.path.isabs(config):
+            config = os.path.join(Globals.root,config)
         c = OmegaConf.load(config)
         with open(weights,'rb') as f:
             weight_bytes = f.read()
@@ -228,6 +234,8 @@ class ModelCache(object):
 
         # look and load a matching vae file. Code borrowed from AUTOMATIC1111 modules/sd_models.py
         if vae:
+            if not os.path.isabs(vae):
+                vae = os.path.normpath(os.path.join(Globals.root,vae))
             if os.path.exists(vae):
                 print(f'  | Loading VAE weights from: {vae}')
                 vae_ckpt = torch.load(vae, map_location="cpu")
@@ -3,13 +3,18 @@ import torch
 import numpy as np
 import warnings
 import sys
+from ldm.invoke.globals import Globals
 
 pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
 
 class CodeFormerRestoration():
     def __init__(self,
-            codeformer_dir='ldm/invoke/restoration/codeformer',
-            codeformer_model_path='weights/codeformer.pth') -> None:
+            codeformer_dir='models/codeformer',
+            codeformer_model_path='codeformer.pth') -> None:
+
+        if not os.path.isabs(codeformer_dir):
+            codeformer_dir = os.path.join(Globals.root, codeformer_dir)
+
         self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
         self.codeformer_model_exists = os.path.isfile(self.model_path)
 
@@ -33,9 +38,20 @@ class CodeFormerRestoration():
 
             cf_class = CodeFormer
 
-            cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
-            checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/invoke/restoration/codeformer/weights'), progress=True)
+            cf = cf_class(
+                dim_embd=512,
+                codebook_size=1024,
+                n_head=8,
+                n_layers=9,
+                connect_list=['32', '64', '128', '256']
+            ).to(device)
+
+            # note that this file should already be downloaded and cached at
+            # this point
+            checkpoint_path = load_file_from_url(url=pretrained_model_url,
+                                                 model_dir=os.path.abspath(os.path.dirname(self.model_path)),
+                                                 progress=True
+                                                 )
             checkpoint = torch.load(checkpoint_path)['params_ema']
             cf.load_state_dict(checkpoint)
             cf.eval()
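
The added comment leans on load_file_from_url (the basicsr/facexlib download helper) returning the cached path without re-downloading when the file already exists in model_dir, which preload_models.py now populates at models/codeformer. A standalone check of the directory the call now targets, assuming Globals.root is '.':

import os

model_path = os.path.join('.', 'models/codeformer', 'codeformer.pth')
print(os.path.abspath(os.path.dirname(model_path)))  # .../models/codeformer, passed as model_dir
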
@@ -3,6 +3,7 @@ import warnings
 import os
 import sys
 import numpy as np
+from ldm.invoke.globals import Globals
 
 from PIL import Image
 
@@ -10,18 +11,21 @@ from PIL import Image
 class GFPGAN():
     def __init__(
             self,
-            gfpgan_model_path='./models/gfpgan/GFPGANv1.4.pth') -> None:
-        self.model_path = os.path.join(gfpgan_model_path)
+            gfpgan_model_path='models/gfpgan/GFPGANv1.4.pth'
+    ) -> None:
+
+        if not os.path.isabs(gfpgan_model_path):
+            gfpgan_model_path=os.path.join(Globals.root,gfpgan_model_path)
+        self.model_path = gfpgan_model_path
         self.gfpgan_model_exists = os.path.isfile(self.model_path)
 
         if not self.gfpgan_model_exists:
             print('## NOT FOUND: GFPGAN model not found at ' + self.model_path)
             return None
 
     def model_exists(self):
         return os.path.isfile(self.model_path)
 
     def process(self, image, strength: float, seed: str = None):
         if seed is not None:
             print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
@@ -1,5 +1,5 @@
 import math
+import os.path
 import torch
 import torch.nn as nn
 from functools import partial
@@ -8,6 +8,7 @@ from einops import rearrange, repeat
 from transformers import CLIPTokenizer, CLIPTextModel
 import kornia
 from ldm.invoke.devices import choose_torch_device
+from ldm.invoke.globals import Globals
 
 from ldm.modules.x_transformer import (
     Encoder,
@@ -98,21 +99,19 @@ class BERTTokenizer(AbstractEncoder):
     """Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
 
     def __init__(
         self, device=choose_torch_device(), vq_interface=True, max_length=77
     ):
         super().__init__()
         from transformers import (
             BertTokenizerFast,
-        ) # TODO: add to reuquirements
+        )
 
-        # Modified to allow to run on non-internet connected compute nodes.
-        # Model needs to be loaded into cache from an internet-connected machine
-        # by running:
-        # from transformers import BertTokenizerFast
-        # BertTokenizerFast.from_pretrained("bert-base-uncased")
+        cache = os.path.join(Globals.root,'models/bert-base-uncased')
         try:
             self.tokenizer = BertTokenizerFast.from_pretrained(
-                'bert-base-uncased', local_files_only=True
+                'bert-base-uncased',
+                cache_dir=cache,
+                local_files_only=True
             )
         except OSError:
             raise SystemExit(
@@ -150,14 +149,14 @@ class BERTEmbedder(AbstractEncoder):
     """Uses the BERT tokenizr model and add some transformer encoder layers"""
 
     def __init__(
         self,
         n_embed,
         n_layer,
         vocab_size=30522,
         max_seq_len=77,
         device=choose_torch_device(),
         use_tokenizer=True,
         embedding_dropout=0.0,
     ):
         super().__init__()
         self.use_tknz_fn = use_tokenizer
@@ -245,10 +244,14 @@ class FrozenCLIPEmbedder(AbstractEncoder):
     ):
         super().__init__()
         self.tokenizer = CLIPTokenizer.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
         )
         self.transformer = CLIPTextModel.from_pretrained(
-            version, local_files_only=True
+            version,
+            cache_dir=os.path.join(Globals.root,'models',version),
+            local_files_only=True
         )
         self.device = device
         self.max_length = max_length
(deleted file, 2 lines)
@@ -1,2 +0,0 @@
-See docs/features/INSTALLING_MODELS.md for how to populate this
-directory with one or more Stable Diffusion model weight files.
@@ -11,9 +11,9 @@ import time
 import traceback
 import yaml
 
-from ldm.invoke.prompt_parser import PromptParser
-
 sys.path.append('.') # corrects a weird problem on Macs
+from ldm.invoke.globals import Globals
+from ldm.invoke.prompt_parser import PromptParser
 from ldm.invoke.readline import get_completer
 from ldm.invoke.args import Args, metadata_dumps, metadata_from_png, dream_cmd_from_png
 from ldm.invoke.pngwriter import PngWriter, retrieve_metadata, write_metadata
@@ -47,6 +47,9 @@ def main():
         print('--max_loaded_models must be >= 1; using 1')
         args.max_loaded_models = 1
 
+    # alert - setting a global here
+    Globals.root=os.path.expanduser(args.root_dir)
+
     from ldm.generate import Generate
 
     # these two lines prevent a horrible warning message from appearing
@@ -57,10 +60,16 @@ def main():
     # Loading Face Restoration and ESRGAN Modules
     gfpgan,codeformer,esrgan = load_face_restoration(opt)
 
-    # make sure the output directory exists
+    # normalize the outdir relative to root and make sure it exists
+    if not os.path.abspath(opt.outdir):
+        opt.outdir=os.path.normpath(os.path.join(Globals.root,opt.outdir))
     if not os.path.exists(opt.outdir):
         os.makedirs(opt.outdir)
 
+    # normalize the config directory relative to root
+    if not os.path.abspath(opt.conf):
+        opt.conf=os.path.normpath(os.path.join(Globals.root,opt.conf))
+
     # load the infile as a list of lines
     if opt.infile:
         try:
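
One caveat in the two added guards above: os.path.abspath() always returns a non-empty (truthy) string, so `if not os.path.abspath(...)` can never fire and the normalization is silently skipped; os.path.isabs() appears to be the intended test, as used at the other call sites in this commit. A two-line demonstration:

import os

print(bool(os.path.abspath('outputs')))  # True even for a relative path, so the guard is dead
print(os.path.isabs('outputs'))          # False -- the test the other call sites use
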
@@ -77,7 +86,7 @@ def main():
     # creating a Generate object:
     try:
         gen = Generate(
-            conf = opt.conf,
+            conf = os.path.join(Globals.root,opt.conf),
             model = opt.model,
             sampler_name = opt.sampler_name,
             embedding_path = opt.embedding_path,
@@ -128,6 +137,8 @@ def main_loop(gen, opt):
     doneAfterInFile = infile is not None
     path_filter = re.compile(r'[<>:"/\\|?*]')
     last_results = list()
+    if not os.path.isabs(opt.conf):
+        opt.conf = os.path.join(Globals.root,opt.conf)
     model_config = OmegaConf.load(opt.conf)
 
     # The readline completer reads history from the .dream_history file located in the
@@ -1,5 +1,6 @@
 from ldm.modules.encoders.modules import FrozenCLIPEmbedder, BERTEmbedder
 from ldm.modules.embedding_manager import EmbeddingManager
+from ldm.modules.globals import Globals
 
 import argparse, os
 from functools import partial
@@ -51,6 +52,13 @@ if __name__ == "__main__":
 
     parser = argparse.ArgumentParser()
 
+    parser.add_argument(
+        "--root_dir",
+        type=str,
+        default='.',
+        help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
+    )
+
     parser.add_argument(
         "--manager_ckpts",
         type=str,
@@ -73,6 +81,7 @@ if __name__ == "__main__":
     )
 
     args = parser.parse_args()
+    Globals.root=args.root_dir
 
     if args.use_bert:
         embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()
scripts/preload_models.py (103 lines changed; mode changed: Normal file → Executable file)
@@ -32,7 +32,9 @@ warnings.filterwarnings('ignore')
 #warnings.filterwarnings('ignore',category=UserWarning)
 
 #--------------------------globals--
-Model_dir = './models/ldm/stable-diffusion-v1/'
+Root_dir = '.'
+Model_dir = 'models'
+Weights_dir = 'ldm/stable-diffusion-v1/'
 Default_config_file = './configs/models.yaml'
 SD_Configs = './configs/stable-diffusion'
 Datasets = {
@@ -253,14 +255,15 @@ This involves a few easy steps.
 # look for legacy model.ckpt in models directory and offer to
 # normalize its name
 def migrate_models_ckpt():
-    if not os.path.exists(os.path.join(Model_dir,'model.ckpt')):
+    model_path = os.path.join(Root_dir,Model_dir,Weights_dir)
+    if not os.path.exists(os.path.join(model_path,'model.ckpt')):
         return
     new_name = Datasets['stable-diffusion-1.4']['file']
     print('You seem to have the Stable Diffusion v4.1 "model.ckpt" already installed.')
     rename = yes_or_no(f'Ok to rename it to "{new_name}" for future reference?')
     if rename:
         print(f'model.ckpt => {new_name}')
-        os.rename(os.path.join(Model_dir,'model.ckpt'),os.path.join(Model_dir,new_name))
+        os.rename(os.path.join(model_path,'model.ckpt'),os.path.join(model_path,new_name))
 
 #---------------------------------------------
 def download_weight_datasets(models:dict, access_token:str):
@@ -271,6 +274,7 @@ def download_weight_datasets(models:dict, access_token:str):
         filename = Datasets[mod]['file']
         success = download_with_resume(
             repo_id=repo_id,
+            model_dir=os.path.join(Root_dir,Model_dir,Weights_dir),
             model_name=filename,
             access_token=access_token
         )
@@ -290,12 +294,12 @@ def download_weight_datasets(models:dict, access_token:str):
     return successful
 
 #---------------------------------------------
-def download_with_resume(repo_id:str, model_name:str, access_token:str)->bool:
-    model_dest = os.path.join(Model_dir, model_name)
-    os.makedirs(os.path.dirname(model_dest), exist_ok=True)
+def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
+    model_dest = os.path.join(model_dir, model_name)
+    os.makedirs(model_dir, exist_ok=True)
     url = hf_hub_url(repo_id, model_name)
 
-    header = {"Authorization": f'Bearer {access_token}'}
+    header = {"Authorization": f'Bearer {access_token}'} if access_token else {}
     open_mode = 'wb'
     exist_size = 0
 
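
download_with_resume now takes the destination directory explicitly and makes the access token optional. The resume technique itself (not visible in this hunk) boils down to sending an HTTP Range header and appending to the partial file; a minimal self-contained sketch, assuming the requests library:

import os
import requests

def fetch_with_resume(url: str, dest: str, token: str = None) -> bool:
    header = {"Authorization": f"Bearer {token}"} if token else {}
    open_mode, exist_size = 'wb', 0
    if os.path.exists(dest):                      # partial file left over?
        exist_size = os.path.getsize(dest)
        header['Range'] = f'bytes={exist_size}-'  # ask the server for the remainder
        open_mode = 'ab'
    resp = requests.get(url, headers=header, stream=True)
    if resp.status_code not in (200, 206):        # 206 = Partial Content
        return False
    with open(dest, open_mode) as f:
        for chunk in resp.iter_content(chunk_size=1 << 20):
            f.write(chunk)
    return True
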
@@ -399,30 +403,27 @@ def new_config_file_contents(successfully_downloaded:dict, Config_file:str)->str
 #---------------------------------------------
 # this will preload the Bert tokenizer fles
 def download_bert():
-    print('Installing bert tokenizer (ignore deprecation errors)...', end='')
-    sys.stdout.flush()
+    print('Installing bert tokenizer (ignore deprecation errors)...', end='',file=sys.stderr)
     with warnings.catch_warnings():
         warnings.filterwarnings('ignore', category=DeprecationWarning)
         from transformers import BertTokenizerFast, AutoFeatureExtractor
-        tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
-    print('...success')
+        download_from_hf(BertTokenizerFast,'bert-base-uncased')
+    print('...success',file=sys.stderr)
 
 #---------------------------------------------
-# this will download requirements for Kornia
-def download_kornia():
-    print('Installing Kornia requirements (ignore deprecation errors)...', end='')
-    sys.stdout.flush()
-    import kornia
-    print('...success')
+def download_from_hf(model_class:object, model_name:str):
+    return model_class.from_pretrained(model_name,
+                                       cache_dir=os.path.join(Root_dir,Model_dir,model_name),
+                                       resume_download=True
+                                       )
 
 #---------------------------------------------
 def download_clip():
-    print('Loading CLIP model (ignore deprecation errors)...',end='')
-    sys.stdout.flush()
+    print('Loading CLIP model (ignore deprecation errors)...',end='',file=sys.stderr)
     version = 'openai/clip-vit-large-patch14'
-    tokenizer = CLIPTokenizer.from_pretrained(version)
-    transformer = CLIPTextModel.from_pretrained(version)
-    print('...success')
+    download_from_hf(CLIPTokenizer,version)
+    download_from_hf(CLIPTextModel,version)
+    print('...success',file=sys.stderr)
 
 #---------------------------------------------
 def download_gfpgan():
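
The new download_from_hf helper funnels every HuggingFace fetch through one cache layout rooted at Root_dir. An equivalent standalone call for one model (needs transformers installed and network access on the first run; resume_download was a supported from_pretrained keyword in transformers of this era):

import os
from transformers import BertTokenizerFast

cache = os.path.join('.', 'models', 'bert-base-uncased')  # Root_dir/Model_dir/model_name
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased',
                                              cache_dir=cache,
                                              resume_download=True)
print(tokenizer.tokenize('hello world'))
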
@@ -464,7 +465,7 @@ def download_gfpgan():
         if not os.path.exists(model_dest):
             print(f'Downloading gfpgan model file {model_url}...',end='')
             os.makedirs(os.path.dirname(model_dest), exist_ok=True)
-            request.urlretrieve(model_url,model_dest)
+            request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
             print('...success')
     except Exception:
         print('Error loading GFPGAN:')
@@ -472,18 +473,18 @@ def download_gfpgan():
 
 #---------------------------------------------
 def download_codeformer():
-    print('Installing CodeFormer model file...',end='')
+    print('Installing CodeFormer model file...',end='',file=sys.stderr)
     try:
         model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
-        model_dest = 'ldm/invoke/restoration/codeformer/weights/codeformer.pth'
+        model_dest = os.path.join(Root_dir,'models/codeformer/codeformer.pth')
         if not os.path.exists(model_dest):
             print('Downloading codeformer model file...')
             os.makedirs(os.path.dirname(model_dest), exist_ok=True)
-            request.urlretrieve(model_url,model_dest)
+            request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
     except Exception:
         print('Error loading CodeFormer:')
         print(traceback.format_exc())
-    print('...success')
+    print('...success',file=sys.stderr)
 
 #---------------------------------------------
 def download_clipseg():
@@ -497,7 +498,7 @@ def download_clipseg():
         if not os.path.exists(model_dest):
             os.makedirs(os.path.dirname(model_dest), exist_ok=True)
             if not os.path.exists(f'{model_dest}/rd64-uni-refined.pth'):
-                request.urlretrieve(model_url,weights_zip)
+                request.urlretrieve(model_url,weights_zip,ProgressBar(os.path.basename(model_dest)))
                 with zipfile.ZipFile(weights_zip,'r') as zip:
                     zip.extractall('models/clipseg')
                 os.remove(weights_zip)
@@ -519,7 +520,7 @@ def download_clipseg():
 
 #-------------------------------------
 def download_safety_checker():
-    print('Installing safety model for NSFW content detection...',end='')
+    print('Installing safety model for NSFW content detection...',end='',file=sys.stderr)
     try:
         from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
         from transformers import AutoFeatureExtractor
@@ -528,9 +529,25 @@ def download_safety_checker():
         print(traceback.format_exc())
         return
     safety_model_id = "CompVis/stable-diffusion-safety-checker"
-    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
-    safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
-    print('...success')
+    download_from_hf(AutoFeatureExtractor,safety_model_id)
+    download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
+    print('...success',file=sys.stderr)
+
+#-------------------------------------
+class ProgressBar():
+    def __init__(self,model_name='file'):
+        self.pbar = None
+        self.name = model_name
+
+    def __call__(self, block_num, block_size, total_size):
+        if not self.pbar:
+            self.pbar=tqdm(desc=self.name,
+                           initial=0,
+                           unit='iB',
+                           unit_scale=True,
+                           total=total_size)
+        downloaded = block_num * block_size
+        self.pbar.update(downloaded)
+
 #-------------------------------------
 if __name__ == '__main__':
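
ProgressBar follows urlretrieve's reporthook contract: the hook is called as hook(block_num, block_size, total_size) after each block. Note, though, that tqdm.update() expects an increment, while the class above passes the cumulative byte count, so the bar will run far past its total. A delta-tracking variant (a sketch, not part of the commit):

from tqdm import tqdm

class IncrementalProgressBar():
    def __init__(self, model_name='file'):
        self.pbar = None
        self.name = model_name
        self.seen = 0

    def __call__(self, block_num, block_size, total_size):
        if not self.pbar:
            self.pbar = tqdm(desc=self.name, initial=0, unit='iB',
                             unit_scale=True, total=total_size)
        downloaded = block_num * block_size
        self.pbar.update(downloaded - self.seen)  # feed tqdm the delta only
        self.seen = downloaded
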
@@ -540,13 +557,26 @@ if __name__ == '__main__':
                         action=argparse.BooleanOptionalAction,
                         default=True,
                         help='run in interactive mode (default)')
+    parser.add_argument('--yes','-y',
+                        dest='yes_to_all',
+                        action=argparse.BooleanOptionalAction,
+                        default=True,
+                        help='run in interactive mode (default)')
     parser.add_argument('--config_file',
                         '-c',
                         dest='config_file',
                         type=str,
                         default='./configs/models.yaml',
                         help='path to configuration file to create')
+    parser.add_argument('--root',
+                        dest='root',
+                        type=str,
+                        default='.',
+                        help='path to root of install directory')
     opt = parser.parse_args()
 
+    # setting a global here
+    Root_dir = os.path.expanduser(opt.root)
+
     try:
         if opt.interactive:
@@ -565,7 +595,6 @@ if __name__ == '__main__':
     update_config_file(successfully_downloaded,opt)
     print('\n** DOWNLOADING SUPPORT MODELS **')
     download_bert()
-    download_kornia()
     download_clip()
     download_gfpgan()
     download_codeformer()