mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Enable upscaling on m1 (#474)
This commit is contained in:
parent
0881d429f2
commit
defafc0e8e
@ -75,52 +75,50 @@ def run_gfpgan(image, strength, seed, upsampler_scale=4):
|
|||||||
|
|
||||||
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
|
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
|
||||||
if bg_upsampler == 'realesrgan':
|
if bg_upsampler == 'realesrgan':
|
||||||
if not torch.cuda.is_available(): # CPU
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
warnings.warn(
|
use_half_precision = False
|
||||||
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
|
|
||||||
'If you really want to use it, please modify the corresponding codes.'
|
|
||||||
)
|
|
||||||
bg_upsampler = None
|
|
||||||
else:
|
else:
|
||||||
model_path = {
|
use_half_precision = True
|
||||||
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
|
||||||
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
|
||||||
}
|
|
||||||
|
|
||||||
if upsampler_scale not in model_path:
|
model_path = {
|
||||||
return None
|
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
||||||
|
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
||||||
|
}
|
||||||
|
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
if upsampler_scale not in model_path:
|
||||||
from realesrgan import RealESRGANer
|
return None
|
||||||
|
|
||||||
if upsampler_scale == 4:
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
model = RRDBNet(
|
from realesrgan import RealESRGANer
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=4,
|
|
||||||
)
|
|
||||||
if upsampler_scale == 2:
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
if upsampler_scale == 4:
|
||||||
scale=upsampler_scale,
|
model = RRDBNet(
|
||||||
model_path=model_path[upsampler_scale],
|
num_in_ch=3,
|
||||||
model=model,
|
num_out_ch=3,
|
||||||
tile=bg_tile,
|
num_feat=64,
|
||||||
tile_pad=10,
|
num_block=23,
|
||||||
pre_pad=0,
|
num_grow_ch=32,
|
||||||
half=True,
|
scale=4,
|
||||||
) # need to set False in CPU mode
|
)
|
||||||
|
if upsampler_scale == 2:
|
||||||
|
model = RRDBNet(
|
||||||
|
num_in_ch=3,
|
||||||
|
num_out_ch=3,
|
||||||
|
num_feat=64,
|
||||||
|
num_block=23,
|
||||||
|
num_grow_ch=32,
|
||||||
|
scale=2,
|
||||||
|
)
|
||||||
|
|
||||||
|
bg_upsampler = RealESRGANer(
|
||||||
|
scale=upsampler_scale,
|
||||||
|
model_path=model_path[upsampler_scale],
|
||||||
|
model=model,
|
||||||
|
tile=bg_tile,
|
||||||
|
tile_pad=10,
|
||||||
|
pre_pad=0,
|
||||||
|
half=use_half_precision,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
bg_upsampler = None
|
bg_upsampler = None
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user