mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
realesrgan inherits precision setting from main program
This commit is contained in:
parent
911c99f125
commit
5f42d08945
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import warnings
|
||||
import numpy as np
|
||||
from ldm.dream.devices import choose_precision, choose_torch_device
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@ -8,17 +9,12 @@ from PIL import Image
|
||||
class ESRGAN():
|
||||
def __init__(self, bg_tile_size=400) -> None:
|
||||
self.bg_tile_size = bg_tile_size
|
||||
device = torch.device(choose_torch_device())
|
||||
precision = choose_precision(device)
|
||||
use_half_precision = precision == 'float16'
|
||||
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
|
||||
def load_esrgan_bg_upsampler(self):
|
||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||
use_half_precision = False
|
||||
else:
|
||||
use_half_precision = True
|
||||
def load_esrgan_bg_upsampler(self, precision):
|
||||
use_half_precision = precision == 'float16'
|
||||
|
||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||
from realesrgan import RealESRGANer
|
||||
@ -39,13 +35,13 @@ class ESRGAN():
|
||||
|
||||
return bg_upsampler
|
||||
|
||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
try:
|
||||
upsampler = self.load_esrgan_bg_upsampler()
|
||||
upsampler = self.load_esrgan_bg_upsampler(precision)
|
||||
except Exception:
|
||||
import traceback
|
||||
import sys
|
||||
|
@ -599,7 +599,8 @@ class Generate:
|
||||
opt,
|
||||
args,
|
||||
image_callback = callback,
|
||||
prefix = prefix
|
||||
prefix = prefix,
|
||||
precision = self.precision,
|
||||
)
|
||||
|
||||
elif tool is None:
|
||||
@ -770,7 +771,7 @@ class Generate:
|
||||
if len(upscale) < 2:
|
||||
upscale.append(0.75)
|
||||
image = self.esrgan.process(
|
||||
image, upscale[1], seed, int(upscale[0]))
|
||||
image, upscale[1], seed, int(upscale[0]), precision=self.precision)
|
||||
else:
|
||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||
except Exception as e:
|
||||
|
Loading…
Reference in New Issue
Block a user