realesrgan and facexlib now download models to correct directory

- fix issue in which both realesrgan and facexlib were downloading
  weight files to source directory

- cleaned up status reporting in load_models.py
This commit is contained in:
Lincoln Stein 2022-11-18 19:29:29 +00:00
parent 303431be89
commit a1e5f17d1e
2 changed files with 43 additions and 49 deletions

View File

@ -1,7 +1,9 @@
import torch
import warnings
import numpy as np
import os
from ldm.invoke.globals import Globals
from PIL import Image
@ -24,7 +26,7 @@ class ESRGAN():
from realesrgan import RealESRGANer
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
model_path = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
scale = 4
bg_upsampler = RealESRGANer(

View File

@ -278,8 +278,8 @@ def download_weight_datasets(models:dict, access_token:str):
for mod in models.keys():
repo_id = Datasets[mod]['repo_id']
filename = Datasets[mod]['file']
print(os.path.join(Globals.root,Model_dir,Weights_dir))
success = download_with_resume(
print(os.path.join(Globals.root,Model_dir,Weights_dir), file=sys.stderr)
success = hf_download_with_resume(
repo_id=repo_id,
model_dir=os.path.join(Globals.root,Model_dir,Weights_dir),
model_name=filename,
@ -301,7 +301,7 @@ def download_weight_datasets(models:dict, access_token:str):
return successful
#---------------------------------------------
def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
model_dest = os.path.join(model_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
@ -349,6 +349,23 @@ def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_toke
print(f'An error occurred while downloading {model_name}: {str(e)}')
return False
return True
#---------------------------------------------
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
try:
print(f'Installing {label} model file {model_url}...',end='',file=sys.stderr)
if not os.path.exists(model_dest):
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
print('',file=sys.stderr)
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
print('...downloaded successfully', file=sys.stderr)
else:
print('...exists', file=sys.stderr)
except Exception:
print('...download failed')
print(f'Error downloading {label} model')
print(traceback.format_exc())
#---------------------------------------------
def update_config_file(successfully_downloaded:dict,opt:dict):
@ -421,6 +438,7 @@ def download_bert():
#---------------------------------------------
def download_from_hf(model_class:object, model_name:str):
print('',file=sys.stderr) # to prevent tqdm from overwriting
return model_class.from_pretrained(model_name,
cache_dir=os.path.join(Globals.root,Model_dir,model_name),
resume_download=True
@ -428,34 +446,23 @@ def download_from_hf(model_class:object, model_name:str):
#---------------------------------------------
def download_clip():
print('Installing CLIP model (ignore deprecation errors)...',end='',file=sys.stderr)
print('Installing CLIP model (ignore deprecation errors)...',file=sys.stderr)
version = 'openai/clip-vit-large-patch14'
print('Tokenizer...',file=sys.stderr, end='')
download_from_hf(CLIPTokenizer,version)
print('Text model...',file=sys.stderr, end='')
download_from_hf(CLIPTextModel,version)
print('...success',file=sys.stderr)
#---------------------------------------------
def download_realesrgan():
print('Installing models from RealESRGAN and facexlib (ignore deprecation errors)...',end='',file=sys.stderr)
try:
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(
scale=4,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')
print('...success',file=sys.stderr)
except Exception:
print('Error loading ESRGAN:')
print(traceback.format_exc())
print('Installing models from RealESRGAN...',file=sys.stderr)
model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
model_dest = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
download_with_progress_bar(model_url, model_dest, 'RealESRGAN')
def download_gfpgan():
print('Installing GFPGAN models...',end='',file=sys.stderr)
print('Installing GFPGAN models...',file=sys.stderr)
for model in (
[
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
@ -471,35 +478,18 @@ def download_gfpgan():
],
):
model_url,model_dest = model[0],os.path.join(Globals.root,model[1])
try:
if not os.path.exists(model_dest):
print(f'Downloading gfpgan model file {model_url}...',end='')
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
print('...success')
except Exception:
print('Error loading GFPGAN:')
print(traceback.format_exc())
print('...success',file=sys.stderr)
download_with_progress_bar(model_url, model_dest, 'GFPGAN weights')
#---------------------------------------------
def download_codeformer():
print('Installing CodeFormer model file...',end='',file=sys.stderr)
try:
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth')
if not os.path.exists(model_dest):
print('Downloading codeformer model file...')
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
except Exception:
print('Error loading CodeFormer:')
print(traceback.format_exc())
print('...success',file=sys.stderr)
print('Installing CodeFormer model file...',file=sys.stderr)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth')
download_with_progress_bar(model_url, model_dest, 'CodeFormer')
#---------------------------------------------
def download_clipseg():
print('Installing clipseg model for text-based masking...',end='')
print('Installing clipseg model for text-based masking...',end='', file=sys.stderr)
import zipfile
try:
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
@ -528,11 +518,11 @@ def download_clipseg():
except Exception:
print('Error installing clipseg model:')
print(traceback.format_exc())
print('...success')
print('...success',file=sys.stderr)
#-------------------------------------
def download_safety_checker():
print('Installing safety model for NSFW content detection...',end='',file=sys.stderr)
print('Installing safety model for NSFW content detection...',file=sys.stderr)
try:
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
@ -541,7 +531,9 @@ def download_safety_checker():
print(traceback.format_exc())
return
safety_model_id = "CompVis/stable-diffusion-safety-checker"
print('AutoFeatureExtractor...', end='',file=sys.stderr)
download_from_hf(AutoFeatureExtractor,safety_model_id)
print('StableDiffusionSafetyChecker...', end='',file=sys.stderr)
download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
print('...success',file=sys.stderr)