mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
realesrgan inherits precision setting from main program
This commit is contained in:
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from ldm.dream.devices import choose_precision, choose_torch_device
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@ -8,17 +9,12 @@ from PIL import Image
|
|||||||
class ESRGAN():
|
class ESRGAN():
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
self.bg_tile_size = bg_tile_size
|
self.bg_tile_size = bg_tile_size
|
||||||
|
device = torch.device(choose_torch_device())
|
||||||
|
precision = choose_precision(device)
|
||||||
|
use_half_precision = precision == 'float16'
|
||||||
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
def load_esrgan_bg_upsampler(self, precision):
|
||||||
use_half_precision = False
|
use_half_precision = precision == 'float16'
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self):
|
|
||||||
if not torch.cuda.is_available(): # CPU or MPS on M1
|
|
||||||
use_half_precision = False
|
|
||||||
else:
|
|
||||||
use_half_precision = True
|
|
||||||
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
@ -39,13 +35,13 @@ class ESRGAN():
|
|||||||
|
|
||||||
return bg_upsampler
|
return bg_upsampler
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upsampler = self.load_esrgan_bg_upsampler()
|
upsampler = self.load_esrgan_bg_upsampler(precision)
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
import sys
|
||||||
|
@ -599,7 +599,8 @@ class Generate:
|
|||||||
opt,
|
opt,
|
||||||
args,
|
args,
|
||||||
image_callback = callback,
|
image_callback = callback,
|
||||||
prefix = prefix
|
prefix = prefix,
|
||||||
|
precision = self.precision,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
@ -770,7 +771,7 @@ class Generate:
|
|||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
upscale.append(0.75)
|
upscale.append(0.75)
|
||||||
image = self.esrgan.process(
|
image = self.esrgan.process(
|
||||||
image, upscale[1], seed, int(upscale[0]))
|
image, upscale[1], seed, int(upscale[0]), precision=self.precision)
|
||||||
else:
|
else:
|
||||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Reference in New Issue
Block a user