Revert "realesrgan inherits precision setting from main program"

This reverts commit 5f42d08945.

This fix was intended to solve issue #939, in which ESRGAN generates
dark images when upscaling 4X on certain GTX cards. However, the fix
apparently causes conflicts with some versions of the ESRGAN library,
and this fix will have to wait until after release of 2.0.
This commit is contained in:
Lincoln Stein 2022-10-06 20:52:38 -04:00
parent d60df54f69
commit 75165957c9
2 changed files with 14 additions and 11 deletions

View File

@ -1,7 +1,6 @@
import torch
import warnings
import numpy as np
from ldm.dream.devices import choose_precision, choose_torch_device
from PIL import Image
@ -9,12 +8,17 @@ from PIL import Image
class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
device = torch.device(choose_torch_device())
precision = choose_precision(device)
use_half_precision = precision == 'float16'
def load_esrgan_bg_upsampler(self, precision):
use_half_precision = precision == 'float16'
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer
@ -35,13 +39,13 @@ class ESRGAN():
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'):
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler(precision)
upsampler = self.load_esrgan_bg_upsampler()
except Exception:
import traceback
import sys

View File

@ -599,8 +599,7 @@ class Generate:
opt,
args,
image_callback = callback,
prefix = prefix,
precision = self.precision,
prefix = prefix
)
elif tool is None:
@ -771,7 +770,7 @@ class Generate:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]), precision=self.precision)
image, upscale[1], seed, int(upscale[0]))
else:
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e: