realesrgan inherits precision setting from main program

This commit is contained in:
Lincoln Stein 2022-10-06 12:23:30 -04:00
parent 911c99f125
commit 5f42d08945
2 changed files with 11 additions and 14 deletions

View File

@ -1,6 +1,7 @@
import torch
import warnings
import numpy as np
from ldm.dream.devices import choose_precision, choose_torch_device
from PIL import Image
@ -8,17 +9,12 @@ from PIL import Image
class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
self.bg_tile_size = bg_tile_size
device = torch.device(choose_torch_device())
precision = choose_precision(device)
use_half_precision = precision == 'float16'
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self):
if not torch.cuda.is_available(): # CPU or MPS on M1
use_half_precision = False
else:
use_half_precision = True
def load_esrgan_bg_upsampler(self, precision):
use_half_precision = precision == 'float16'
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
from realesrgan import RealESRGANer
@ -39,13 +35,13 @@ class ESRGAN():
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
try:
upsampler = self.load_esrgan_bg_upsampler()
upsampler = self.load_esrgan_bg_upsampler(precision)
except Exception:
import traceback
import sys

View File

@ -599,7 +599,8 @@ class Generate:
opt,
args,
image_callback = callback,
prefix = prefix
prefix = prefix,
precision = self.precision,
)
elif tool is None:
@ -770,7 +771,7 @@ class Generate:
if len(upscale) < 2:
upscale.append(0.75)
image = self.esrgan.process(
image, upscale[1], seed, int(upscale[0]))
image, upscale[1], seed, int(upscale[0]), precision=self.precision)
else:
print(">> ESRGAN is disabled. Image not upscaled.")
except Exception as e: