Enable upscaling on m1 (#474)

This commit is contained in:
Dominic Letz 2022-09-11 18:51:01 +02:00 committed by GitHub
parent 0881d429f2
commit defafc0e8e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -75,52 +75,50 @@ def run_gfpgan(image, strength, seed, upsampler_scale=4):
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400): def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
if bg_upsampler == 'realesrgan': if bg_upsampler == 'realesrgan':
if not torch.cuda.is_available(): # CPU if not torch.cuda.is_available(): # CPU or MPS on M1
warnings.warn( use_half_precision = False
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
'If you really want to use it, please modify the corresponding codes.'
)
bg_upsampler = None
else: else:
model_path = { use_half_precision = True
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
if upsampler_scale not in model_path: model_path = {
return None 2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
}
from basicsr.archs.rrdbnet_arch import RRDBNet if upsampler_scale not in model_path:
from realesrgan import RealESRGANer return None
if upsampler_scale == 4: from basicsr.archs.rrdbnet_arch import RRDBNet
model = RRDBNet( from realesrgan import RealESRGANer
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer( if upsampler_scale == 4:
scale=upsampler_scale, model = RRDBNet(
model_path=model_path[upsampler_scale], num_in_ch=3,
model=model, num_out_ch=3,
tile=bg_tile, num_feat=64,
tile_pad=10, num_block=23,
pre_pad=0, num_grow_ch=32,
half=True, scale=4,
) # need to set False in CPU mode )
if upsampler_scale == 2:
model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=2,
)
bg_upsampler = RealESRGANer(
scale=upsampler_scale,
model_path=model_path[upsampler_scale],
model=model,
tile=bg_tile,
tile_pad=10,
pre_pad=0,
half=use_half_precision,
)
else: else:
bg_upsampler = None bg_upsampler = None