mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
ESRGAN Improvements
This commit is contained in:
parent
d63897fc39
commit
d80fff70f2
@ -594,7 +594,7 @@ class Args(object):
|
|||||||
'--upscale',
|
'--upscale',
|
||||||
nargs='+',
|
nargs='+',
|
||||||
type=float,
|
type=float,
|
||||||
help='Scale factor (2, 4) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
|
help='Scale factor (1, 2, 3, 4, etc..) for upscaling final output followed by upscaling strength (0-1.0). If strength not specified, defaults to 0.75',
|
||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
postprocessing_group.add_argument(
|
postprocessing_group.add_argument(
|
||||||
|
@ -14,73 +14,53 @@ class ESRGAN():
|
|||||||
else:
|
else:
|
||||||
use_half_precision = True
|
use_half_precision = True
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, upsampler_scale):
|
def load_esrgan_bg_upsampler(self):
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
use_half_precision = False
|
use_half_precision = False
|
||||||
else:
|
else:
|
||||||
use_half_precision = True
|
use_half_precision = True
|
||||||
|
|
||||||
model_path = {
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
from realesrgan import RealESRGANer
|
||||||
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
|
||||||
}
|
|
||||||
|
|
||||||
if upsampler_scale not in model_path:
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||||
return None
|
model_path = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth'
|
||||||
else:
|
scale = 4
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
||||||
from realesrgan import RealESRGANer
|
|
||||||
|
|
||||||
if upsampler_scale == 4:
|
bg_upsampler = RealESRGANer(
|
||||||
model = RRDBNet(
|
scale=scale,
|
||||||
num_in_ch=3,
|
model_path=model_path,
|
||||||
num_out_ch=3,
|
model=model,
|
||||||
num_feat=64,
|
tile=self.bg_tile_size,
|
||||||
num_block=23,
|
tile_pad=10,
|
||||||
num_grow_ch=32,
|
pre_pad=0,
|
||||||
scale=4,
|
half=use_half_precision,
|
||||||
)
|
)
|
||||||
if upsampler_scale == 2:
|
|
||||||
model = RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
bg_upsampler = RealESRGANer(
|
|
||||||
scale=upsampler_scale,
|
|
||||||
model_path=model_path[upsampler_scale],
|
|
||||||
model=model,
|
|
||||||
tile=self.bg_tile_size,
|
|
||||||
tile_pad=10,
|
|
||||||
pre_pad=0,
|
|
||||||
half=use_half_precision,
|
|
||||||
)
|
|
||||||
|
|
||||||
return bg_upsampler
|
return bg_upsampler
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
||||||
if seed is not None:
|
|
||||||
print(
|
|
||||||
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
|
||||||
)
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upsampler = self.load_esrgan_bg_upsampler(upsampler_scale)
|
upsampler = self.load_esrgan_bg_upsampler()
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
|
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
|
|
||||||
|
if upsampler_scale == 0:
|
||||||
|
print('>> Real-ESRGAN: Invalid scaling option. Image not upscaled.')
|
||||||
|
return image
|
||||||
|
|
||||||
|
if seed is not None:
|
||||||
|
print(
|
||||||
|
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
||||||
|
)
|
||||||
|
|
||||||
output, _ = upsampler.enhance(
|
output, _ = upsampler.enhance(
|
||||||
np.array(image, dtype=np.uint8),
|
np.array(image, dtype=np.uint8),
|
||||||
outscale=upsampler_scale,
|
outscale=upsampler_scale,
|
||||||
|
@ -721,14 +721,6 @@ class Generate:
|
|||||||
for r in image_list:
|
for r in image_list:
|
||||||
image, seed = r
|
image, seed = r
|
||||||
try:
|
try:
|
||||||
if upscale is not None:
|
|
||||||
if self.esrgan is not None:
|
|
||||||
if len(upscale) < 2:
|
|
||||||
upscale.append(0.75)
|
|
||||||
image = self.esrgan.process(
|
|
||||||
image, upscale[1], seed, int(upscale[0]))
|
|
||||||
else:
|
|
||||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
|
||||||
if strength > 0:
|
if strength > 0:
|
||||||
if self.gfpgan is not None or self.codeformer is not None:
|
if self.gfpgan is not None or self.codeformer is not None:
|
||||||
if facetool == 'gfpgan':
|
if facetool == 'gfpgan':
|
||||||
@ -744,6 +736,14 @@ class Generate:
|
|||||||
image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity)
|
image = self.codeformer.process(image=image, strength=strength, device=cf_device, seed=seed, fidelity=codeformer_fidelity)
|
||||||
else:
|
else:
|
||||||
print(">> Face Restoration is disabled.")
|
print(">> Face Restoration is disabled.")
|
||||||
|
if upscale is not None:
|
||||||
|
if self.esrgan is not None:
|
||||||
|
if len(upscale) < 2:
|
||||||
|
upscale.append(0.75)
|
||||||
|
image = self.esrgan.process(
|
||||||
|
image, upscale[1], seed, int(upscale[0]))
|
||||||
|
else:
|
||||||
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(
|
print(
|
||||||
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}'
|
||||||
|
@ -49,33 +49,13 @@ except ModuleNotFoundError:
|
|||||||
if gfpgan:
|
if gfpgan:
|
||||||
print('Loading models from RealESRGAN and facexlib')
|
print('Loading models from RealESRGAN and facexlib')
|
||||||
try:
|
try:
|
||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||||
|
|
||||||
RealESRGANer(
|
|
||||||
scale=2,
|
|
||||||
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
|
||||||
model=RRDBNet(
|
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=2,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
RealESRGANer(
|
RealESRGANer(
|
||||||
scale=4,
|
scale=4,
|
||||||
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
|
||||||
model=RRDBNet(
|
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
|
||||||
num_in_ch=3,
|
|
||||||
num_out_ch=3,
|
|
||||||
num_feat=64,
|
|
||||||
num_block=23,
|
|
||||||
num_grow_ch=32,
|
|
||||||
scale=4,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
FaceRestoreHelper(1, det_model='retinaface_resnet50')
|
FaceRestoreHelper(1, det_model='retinaface_resnet50')
|
||||||
|
Loading…
Reference in New Issue
Block a user