diff --git a/ldm/invoke/restoration/realesrgan.py b/ldm/invoke/restoration/realesrgan.py index 4b83fcbb10..97c2de33fe 100644 --- a/ldm/invoke/restoration/realesrgan.py +++ b/ldm/invoke/restoration/realesrgan.py @@ -1,7 +1,9 @@ import torch import warnings import numpy as np +import os +from ldm.invoke.globals import Globals from PIL import Image @@ -24,7 +26,7 @@ class ESRGAN(): from realesrgan import RealESRGANer model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' + model_path = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth') scale = 4 bg_upsampler = RealESRGANer( diff --git a/scripts/load_models.py b/scripts/load_models.py index 51a0d492c4..062f081506 100755 --- a/scripts/load_models.py +++ b/scripts/load_models.py @@ -278,8 +278,8 @@ def download_weight_datasets(models:dict, access_token:str): for mod in models.keys(): repo_id = Datasets[mod]['repo_id'] filename = Datasets[mod]['file'] - print(os.path.join(Globals.root,Model_dir,Weights_dir)) - success = download_with_resume( + print(os.path.join(Globals.root,Model_dir,Weights_dir), file=sys.stderr) + success = hf_download_with_resume( repo_id=repo_id, model_dir=os.path.join(Globals.root,Model_dir,Weights_dir), model_name=filename, @@ -301,7 +301,7 @@ def download_weight_datasets(models:dict, access_token:str): return successful #--------------------------------------------- -def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool: +def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool: model_dest = os.path.join(model_dir, model_name) os.makedirs(model_dir, exist_ok=True) @@ -349,6 +349,23 @@ def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_toke print(f'An error occurred while downloading {model_name}: {str(e)}') return False return True + +#--------------------------------------------- +def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'): + try: + print(f'Installing {label} model file {model_url}...',end='',file=sys.stderr) + if not os.path.exists(model_dest): + os.makedirs(os.path.dirname(model_dest), exist_ok=True) + print('',file=sys.stderr) + request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest))) + print('...downloaded successfully', file=sys.stderr) + else: + print('...exists', file=sys.stderr) + except Exception: + print('...download failed') + print(f'Error downloading {label} model') + print(traceback.format_exc()) + #--------------------------------------------- def update_config_file(successfully_downloaded:dict,opt:dict): @@ -421,6 +438,7 @@ def download_bert(): #--------------------------------------------- def download_from_hf(model_class:object, model_name:str): + print('',file=sys.stderr) # to prevent tqdm from overwriting return model_class.from_pretrained(model_name, cache_dir=os.path.join(Globals.root,Model_dir,model_name), resume_download=True @@ -428,34 +446,23 @@ def download_from_hf(model_class:object, model_name:str): #--------------------------------------------- def download_clip(): - print('Installing CLIP model (ignore deprecation errors)...',end='',file=sys.stderr) + print('Installing CLIP model (ignore deprecation errors)...',file=sys.stderr) version = 'openai/clip-vit-large-patch14' + print('Tokenizer...',file=sys.stderr, end='') download_from_hf(CLIPTokenizer,version) + print('Text model...',file=sys.stderr, end='') download_from_hf(CLIPTextModel,version) print('...success',file=sys.stderr) #--------------------------------------------- def download_realesrgan(): - print('Installing models from RealESRGAN and facexlib (ignore deprecation errors)...',end='',file=sys.stderr) - try: - from realesrgan import RealESRGANer - from realesrgan.archs.srvgg_arch import SRVGGNetCompact - from facexlib.utils.face_restoration_helper import FaceRestoreHelper - - RealESRGANer( - scale=4, - model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth', - model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu') - ) - - FaceRestoreHelper(1, det_model='retinaface_resnet50') - print('...success',file=sys.stderr) - except Exception: - print('Error loading ESRGAN:') - print(traceback.format_exc()) + print('Installing models from RealESRGAN...',file=sys.stderr) + model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth' + model_dest = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth') + download_with_progress_bar(model_url, model_dest, 'RealESRGAN') def download_gfpgan(): - print('Installing GFPGAN models...',end='',file=sys.stderr) + print('Installing GFPGAN models...',file=sys.stderr) for model in ( [ 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', @@ -471,35 +478,18 @@ def download_gfpgan(): ], ): model_url,model_dest = model[0],os.path.join(Globals.root,model[1]) - try: - if not os.path.exists(model_dest): - print(f'Downloading gfpgan model file {model_url}...',end='') - os.makedirs(os.path.dirname(model_dest), exist_ok=True) - request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest))) - print('...success') - except Exception: - print('Error loading GFPGAN:') - print(traceback.format_exc()) - print('...success',file=sys.stderr) + download_with_progress_bar(model_url, model_dest, 'GFPGAN weights') #--------------------------------------------- def download_codeformer(): - print('Installing CodeFormer model file...',end='',file=sys.stderr) - try: - model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' - model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth') - if not os.path.exists(model_dest): - print('Downloading codeformer model file...') - os.makedirs(os.path.dirname(model_dest), exist_ok=True) - request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest))) - except Exception: - print('Error loading CodeFormer:') - print(traceback.format_exc()) - print('...success',file=sys.stderr) - + print('Installing CodeFormer model file...',file=sys.stderr) + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth') + download_with_progress_bar(model_url, model_dest, 'CodeFormer') + #--------------------------------------------- def download_clipseg(): - print('Installing clipseg model for text-based masking...',end='') + print('Installing clipseg model for text-based masking...',end='', file=sys.stderr) import zipfile try: model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download' @@ -528,11 +518,11 @@ def download_clipseg(): except Exception: print('Error installing clipseg model:') print(traceback.format_exc()) - print('...success') + print('...success',file=sys.stderr) #------------------------------------- def download_safety_checker(): - print('Installing safety model for NSFW content detection...',end='',file=sys.stderr) + print('Installing safety model for NSFW content detection...',file=sys.stderr) try: from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker from transformers import AutoFeatureExtractor @@ -541,7 +531,9 @@ def download_safety_checker(): print(traceback.format_exc()) return safety_model_id = "CompVis/stable-diffusion-safety-checker" + print('AutoFeatureExtractor...', end='',file=sys.stderr) download_from_hf(AutoFeatureExtractor,safety_model_id) + print('StableDiffusionSafetyChecker...', end='',file=sys.stderr) download_from_hf(StableDiffusionSafetyChecker,safety_model_id) print('...success',file=sys.stderr)