Revert "realesrgan inherits precision setting from main program"

This reverts commit 5f42d08945.

This fix was intended to solve issue #939, in which ESRGAN generates
dark images when upscaling 4X on certain GTX cards. However, the fix
apparently causes conflicts with some versions of the ESRGAN library,
and this fix will have to wait until after release of 2.0.
This commit is contained in:
Lincoln Stein
2022-10-06 20:52:38 -04:00
parent d60df54f69
commit 75165957c9
2 changed files with 14 additions and 11 deletions

View File

@ -1,7 +1,6 @@
import torch import torch
import warnings import warnings
import numpy as np import numpy as np
from ldm.dream.devices import choose_precision, choose_torch_device
from PIL import Image from PIL import Image
@ -9,12 +8,17 @@ from PIL import Image
class ESRGAN(): class ESRGAN():
def __init__(self, bg_tile_size=400) -> None: def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size self.bg_tile_size = bg_tile_size
device = torch.device(choose_torch_device())
precision = choose_precision(device)
use_half_precision = precision == 'float16'
def load_esrgan_bg_upsampler(self, precision): if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = precision == 'float16' use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
from realesrgan.archs.srvgg_arch import SRVGGNetCompact from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer from realesrgan import RealESRGANer
@ -35,13 +39,13 @@ class ESRGAN():
return bg_upsampler return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'): def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
with warnings.catch_warnings(): with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning) warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning) warnings.filterwarnings('ignore', category=UserWarning)
try: try:
upsampler = self.load_esrgan_bg_upsampler(precision) upsampler = self.load_esrgan_bg_upsampler()
except Exception: except Exception:
import traceback import traceback
import sys import sys

View File

@ -599,8 +599,7 @@ class Generate:
opt, opt,
args, args,
image_callback = callback, image_callback = callback,
prefix = prefix, prefix = prefix
precision = self.precision,
) )
elif tool is None: elif tool is None:
@ -771,7 +770,7 @@ class Generate:
if len(upscale) < 2: if len(upscale) < 2:
upscale.append(0.75) upscale.append(0.75)
image = self.esrgan.process( image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]), precision=self.precision) image, upscale[1], seed, int(upscale[0]))
else: else:
print(">> ESRGAN is disabled. Image not upscaled.") print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e: except Exception as e: