Change denoise_str to an arg instead of a class variable

This commit is contained in:
blessedcoolant 2023-02-09 20:16:23 +13:00
parent 5590c73af2
commit 0503680efa
3 changed files with 9 additions and 10 deletions

View File

@ -1057,7 +1057,7 @@ def load_face_restoration(opt):
else:
print('>> Face restoration disabled')
if opt.esrgan:
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile, opt.esrgan_denoise_str)
esrgan = restoration.load_esrgan(opt.esrgan_bg_tile)
else:
print('>> Upscaling disabled')
else:

View File

@ -31,8 +31,8 @@ class Restoration():
return CodeFormerRestoration()
# Upscale Models
def load_esrgan(self, esrgan_bg_tile=400, denoise_str=0.9):
def load_esrgan(self, esrgan_bg_tile=400):
from ldm.invoke.restoration.realesrgan import ESRGAN
esrgan = ESRGAN(esrgan_bg_tile, denoise_str)
esrgan = ESRGAN(esrgan_bg_tile)
print('>> ESRGAN Initialized')
return esrgan;

View File

@ -8,16 +8,15 @@ from PIL import Image
from PIL.Image import Image as ImageType
class ESRGAN():
def __init__(self, bg_tile_size=400, denoise_str=0.9) -> None:
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
self.denoise_str=denoise_str
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self):
def load_esrgan_bg_upsampler(self, denoise_str):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
@ -36,7 +35,7 @@ class ESRGAN():
model_path=[model_path, wdn_model_path],
model=model,
tile=self.bg_tile_size,
dni_weight=[self.denoise_str, 1 - self.denoise_str],
dni_weight=[denoise_str, 1 - denoise_str],
tile_pad=10,
pre_pad=0,
half=use_half_precision,
@ -44,13 +43,13 @@ class ESRGAN():
return bg_upsampler
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2):
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2, denoise_str: float = 0.75):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler()
upsampler = self.load_esrgan_bg_upsampler(denoise_str)
except Exception:
import traceback
import sys
@ -63,7 +62,7 @@ class ESRGAN():
if seed is not None:
print(
f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{self.denoise_str}'
f'>> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}'
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")