mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
realesrgan and facexlib now download models to correct directory
- fix issue in which both realesrgan and facexlib were downloading weight files to source directory - cleaned up status reporting in load_models.py
This commit is contained in:
parent
303431be89
commit
a1e5f17d1e
@ -1,7 +1,9 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
from ldm.invoke.globals import Globals
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@ -24,7 +26,7 @@ class ESRGAN():
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
|
||||
model_path = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
|
||||
scale = 4
|
||||
|
||||
bg_upsampler = RealESRGANer(
|
||||
|
@ -278,8 +278,8 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
for mod in models.keys():
|
||||
repo_id = Datasets[mod]['repo_id']
|
||||
filename = Datasets[mod]['file']
|
||||
print(os.path.join(Globals.root,Model_dir,Weights_dir))
|
||||
success = download_with_resume(
|
||||
print(os.path.join(Globals.root,Model_dir,Weights_dir), file=sys.stderr)
|
||||
success = hf_download_with_resume(
|
||||
repo_id=repo_id,
|
||||
model_dir=os.path.join(Globals.root,Model_dir,Weights_dir),
|
||||
model_name=filename,
|
||||
@ -301,7 +301,7 @@ def download_weight_datasets(models:dict, access_token:str):
|
||||
return successful
|
||||
|
||||
#---------------------------------------------
|
||||
def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
|
||||
def hf_download_with_resume(repo_id:str, model_dir:str, model_name:str, access_token:str=None)->bool:
|
||||
model_dest = os.path.join(model_dir, model_name)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
|
||||
@ -350,6 +350,23 @@ def download_with_resume(repo_id:str, model_dir:str, model_name:str, access_toke
|
||||
return False
|
||||
return True
|
||||
|
||||
#---------------------------------------------
|
||||
def download_with_progress_bar(model_url:str, model_dest:str, label:str='the'):
|
||||
try:
|
||||
print(f'Installing {label} model file {model_url}...',end='',file=sys.stderr)
|
||||
if not os.path.exists(model_dest):
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
print('',file=sys.stderr)
|
||||
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
|
||||
print('...downloaded successfully', file=sys.stderr)
|
||||
else:
|
||||
print('...exists', file=sys.stderr)
|
||||
except Exception:
|
||||
print('...download failed')
|
||||
print(f'Error downloading {label} model')
|
||||
print(traceback.format_exc())
|
||||
|
||||
|
||||
#---------------------------------------------
|
||||
def update_config_file(successfully_downloaded:dict,opt:dict):
|
||||
config_file = opt.config_file or Default_config_file
|
||||
@ -421,6 +438,7 @@ def download_bert():
|
||||
|
||||
#---------------------------------------------
|
||||
def download_from_hf(model_class:object, model_name:str):
|
||||
print('',file=sys.stderr) # to prevent tqdm from overwriting
|
||||
return model_class.from_pretrained(model_name,
|
||||
cache_dir=os.path.join(Globals.root,Model_dir,model_name),
|
||||
resume_download=True
|
||||
@ -428,34 +446,23 @@ def download_from_hf(model_class:object, model_name:str):
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clip():
|
||||
print('Installing CLIP model (ignore deprecation errors)...',end='',file=sys.stderr)
|
||||
print('Installing CLIP model (ignore deprecation errors)...',file=sys.stderr)
|
||||
version = 'openai/clip-vit-large-patch14'
|
||||
print('Tokenizer...',file=sys.stderr, end='')
|
||||
download_from_hf(CLIPTokenizer,version)
|
||||
print('Text model...',file=sys.stderr, end='')
|
||||
download_from_hf(CLIPTextModel,version)
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#---------------------------------------------
|
||||
def download_realesrgan():
|
||||
print('Installing models from RealESRGAN and facexlib (ignore deprecation errors)...',end='',file=sys.stderr)
|
||||
try:
|
||||
from realesrgan import RealESRGANer
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
|
||||
RealESRGANer(
|
||||
scale=4,
|
||||
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
|
||||
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||
)
|
||||
|
||||
FaceRestoreHelper(1, det_model='retinaface_resnet50')
|
||||
print('...success',file=sys.stderr)
|
||||
except Exception:
|
||||
print('Error loading ESRGAN:')
|
||||
print(traceback.format_exc())
|
||||
print('Installing models from RealESRGAN...',file=sys.stderr)
|
||||
model_url = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
|
||||
model_dest = os.path.join(Globals.root,'models/realesrgan/realesr-general-x4v3.pth')
|
||||
download_with_progress_bar(model_url, model_dest, 'RealESRGAN')
|
||||
|
||||
def download_gfpgan():
|
||||
print('Installing GFPGAN models...',end='',file=sys.stderr)
|
||||
print('Installing GFPGAN models...',file=sys.stderr)
|
||||
for model in (
|
||||
[
|
||||
'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
|
||||
@ -471,35 +478,18 @@ def download_gfpgan():
|
||||
],
|
||||
):
|
||||
model_url,model_dest = model[0],os.path.join(Globals.root,model[1])
|
||||
try:
|
||||
if not os.path.exists(model_dest):
|
||||
print(f'Downloading gfpgan model file {model_url}...',end='')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
|
||||
print('...success')
|
||||
except Exception:
|
||||
print('Error loading GFPGAN:')
|
||||
print(traceback.format_exc())
|
||||
print('...success',file=sys.stderr)
|
||||
download_with_progress_bar(model_url, model_dest, 'GFPGAN weights')
|
||||
|
||||
#---------------------------------------------
|
||||
def download_codeformer():
|
||||
print('Installing CodeFormer model file...',end='',file=sys.stderr)
|
||||
try:
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth')
|
||||
if not os.path.exists(model_dest):
|
||||
print('Downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
request.urlretrieve(model_url,model_dest,ProgressBar(os.path.basename(model_dest)))
|
||||
except Exception:
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success',file=sys.stderr)
|
||||
print('Installing CodeFormer model file...',file=sys.stderr)
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = os.path.join(Globals.root,'models/codeformer/codeformer.pth')
|
||||
download_with_progress_bar(model_url, model_dest, 'CodeFormer')
|
||||
|
||||
#---------------------------------------------
|
||||
def download_clipseg():
|
||||
print('Installing clipseg model for text-based masking...',end='')
|
||||
print('Installing clipseg model for text-based masking...',end='', file=sys.stderr)
|
||||
import zipfile
|
||||
try:
|
||||
model_url = 'https://owncloud.gwdg.de/index.php/s/ioHbRzFx6th32hn/download'
|
||||
@ -528,11 +518,11 @@ def download_clipseg():
|
||||
except Exception:
|
||||
print('Error installing clipseg model:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
#-------------------------------------
|
||||
def download_safety_checker():
|
||||
print('Installing safety model for NSFW content detection...',end='',file=sys.stderr)
|
||||
print('Installing safety model for NSFW content detection...',file=sys.stderr)
|
||||
try:
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from transformers import AutoFeatureExtractor
|
||||
@ -541,7 +531,9 @@ def download_safety_checker():
|
||||
print(traceback.format_exc())
|
||||
return
|
||||
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
||||
print('AutoFeatureExtractor...', end='',file=sys.stderr)
|
||||
download_from_hf(AutoFeatureExtractor,safety_model_id)
|
||||
print('StableDiffusionSafetyChecker...', end='',file=sys.stderr)
|
||||
download_from_hf(StableDiffusionSafetyChecker,safety_model_id)
|
||||
print('...success',file=sys.stderr)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user