ESRGAN Improvements

This commit is contained in:
blessedcoolant 2022-09-26 11:11:59 +13:00 committed by Lincoln Stein
parent d63897fc39
commit d80fff70f2
4 changed files with 37 additions and 77 deletions

View File

@ -594,7 +594,7 @@ class Args(object):
'--upscale', '--upscale',
nargs='+', nargs='+',
type=float, type=float,
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75', help='Scale factor (1, 2, 3, 4, etc..) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
default=None, default=None,
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(

View File

@ -14,73 +14,53 @@ class ESRGAN():
else: else:
use_half_precision = True use_half_precision = True
def load_esrgan_bg_upsampler(self, upsampler_scale): def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1 if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False use_half_precision = False
else: else:
use_half_precision = True use_half_precision = True
model_path = { from realesrgan.archs.srvgg_arch import SRVGGNetCompact
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth', from realesrgan import RealESRGANer
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path: model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
return None model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
else: scale = 4
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
if upsampler_scale == 4: bg_upsampler = RealESRGANer(
model = RRDBNet( scale=scale,
num_in_ch=3, model_path=model_path,
num_out_ch=3, model=model,
num_feat=64, tile=self.bg_tile_size,
num_block=23, tile_pad=10,
num_grow_ch=32, pre_pad=0,
scale=4, half=use_half_precision,
) )
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=self.bg_tile_size,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2): def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
try: try:
upsampler = self.load_esrgan_bg_upsampler(upsampler_scale) upsampler = self.load_esrgan_bg_upsampler()
except Exception: except Exception:
import traceback import traceback
import sys import sys
print('>> Error loading Real-ESRGAN:', file=sys.stderr) print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
return image
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
output, _ = upsampler.enhance( output, _ = upsampler.enhance(
np.array(image, dtype=np.uint8), np.array(image, dtype=np.uint8),
outscale=upsampler_scale, outscale=upsampler_scale,

View File

@ -721,14 +721,6 @@ class Generate:
for r in image_list: for r in image_list:
image, seed = r image, seed = r
try: try:
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
print(">> ESRGAN is disabled. Image not upscaled.")
if strength > 0: if strength > 0:
if self.gfpgan is not None or self.codeformer is not None: if self.gfpgan is not None or self.codeformer is not None:
if facetool == 'gfpgan': if facetool == 'gfpgan':
@ -744,6 +736,14 @@ class Generate:
image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity) image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity)
else: else:
print(">> Face Restoration is disabled.") print(">> Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e: except Exception as e:
print( print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'

View File

@ -49,33 +49,13 @@ except ModuleNotFoundError:
if gfpgan: if gfpgan:
print('Loading models from RealESRGAN and facexlib') print('Loading models from RealESRGAN and facexlib')
try: try:
from basicsr.archs.rrdbnet_arch import RRDBNet from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
RealESRGANer( RealESRGANer(
scale=4, scale=4,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth', model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
model=RRDBNet( model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
),
) )
FaceRestoreHelper(1, det_model='retinaface_resnet50') FaceRestoreHelper(1, det_model='retinaface_resnet50')