ESRGAN Improvements

This commit is contained in:
blessedcoolant 2022-09-26 11:11:59 +13:00 committed by Lincoln Stein
parent d63897fc39
commit d80fff70f2
4 changed files with 37 additions and 77 deletions

View File

@ -594,7 +594,7 @@ class Args(object):
'--upscale',
nargs='+',
type=float,
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
help='Scale factor (1, 2, 3, 4, etc..) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
default=None,
)
postprocessing_group.add_argument(

View File

@ -14,73 +14,53 @@ class ESRGAN():
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, upsampler_scale):
def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
model_path = {
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer
if upsampler_scale not in model_path:
return None
else:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
scale = 4
if upsampler_scale == 4:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=self.bg_tile_size,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
bg_upsampler = RealESRGANer(
scale=scale,
model_path=model_path,
model=model,
tile=self.bg_tile_size,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
upsampler = self.load_esrgan_bg_upsampler()
except Exception:
import traceback
import sys
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
if upsampler_scale == 0:
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
return image
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
output, _ = upsampler.enhance(
np.array(image, dtype=np.uint8),
outscale=upsampler_scale,

View File

@ -721,14 +721,6 @@ class Generate:
for r in image_list:
image, seed = r
try:
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
print(">> ESRGAN is disabled. Image not upscaled.")
if strength > 0:
if self.gfpgan is not None or self.codeformer is not None:
if facetool == 'gfpgan':
@ -744,6 +736,14 @@ class Generate:
image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity)
else:
print(">> Face Restoration is disabled.")
if upscale is not None:
if self.esrgan is not None:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
else:
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e:
print(
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'

View File

@ -49,33 +49,13 @@ except ModuleNotFoundError:
if gfpgan:
print('Loading models from RealESRGAN and facexlib')
try:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
RealESRGANer(
scale=2,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
),
)
RealESRGANer(
scale=4,
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
model=RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
),
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
)
FaceRestoreHelper(1, det_model='retinaface_resnet50')