mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Revert "realesrgan inherits precision setting from main program"
This reverts commit 5f42d08945
.
This fix was intended to solve issue #939, in which ESRGAN generates
dark images when upscaling 4X on certain GTX cards. However, the fix
apparently causes conflicts with some versions of the ESRGAN library,
and this fix will have to wait until after release of 2.0.
This commit is contained in:
@ -1,7 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
import warnings
|
import warnings
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ldm.dream.devices import choose_precision, choose_torch_device
|
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
@ -9,12 +8,17 @@ from PIL import Image
|
|||||||
class ESRGAN():
|
class ESRGAN():
|
||||||
def __init__(self, bg_tile_size=400) -> None:
|
def __init__(self, bg_tile_size=400) -> None:
|
||||||
self.bg_tile_size = bg_tile_size
|
self.bg_tile_size = bg_tile_size
|
||||||
device = torch.device(choose_torch_device())
|
|
||||||
precision = choose_precision(device)
|
|
||||||
use_half_precision = precision == 'float16'
|
|
||||||
|
|
||||||
def load_esrgan_bg_upsampler(self, precision):
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
use_half_precision = precision == 'float16'
|
use_half_precision = False
|
||||||
|
else:
|
||||||
|
use_half_precision = True
|
||||||
|
|
||||||
|
def load_esrgan_bg_upsampler(self):
|
||||||
|
if not torch.cuda.is_available(): # CPU or MPS on M1
|
||||||
|
use_half_precision = False
|
||||||
|
else:
|
||||||
|
use_half_precision = True
|
||||||
|
|
||||||
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
@ -35,13 +39,13 @@ class ESRGAN():
|
|||||||
|
|
||||||
return bg_upsampler
|
return bg_upsampler
|
||||||
|
|
||||||
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2, precision: str = 'float16'):
|
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||||
warnings.filterwarnings('ignore', category=UserWarning)
|
warnings.filterwarnings('ignore', category=UserWarning)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
upsampler = self.load_esrgan_bg_upsampler(precision)
|
upsampler = self.load_esrgan_bg_upsampler()
|
||||||
except Exception:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
import sys
|
||||||
|
@ -599,8 +599,7 @@ class Generate:
|
|||||||
opt,
|
opt,
|
||||||
args,
|
args,
|
||||||
image_callback = callback,
|
image_callback = callback,
|
||||||
prefix = prefix,
|
prefix = prefix
|
||||||
precision = self.precision,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif tool is None:
|
elif tool is None:
|
||||||
@ -771,7 +770,7 @@ class Generate:
|
|||||||
if len(upscale) < 2:
|
if len(upscale) < 2:
|
||||||
upscale.append(0.75)
|
upscale.append(0.75)
|
||||||
image = self.esrgan.process(
|
image = self.esrgan.process(
|
||||||
image, upscale[1], seed, int(upscale[0]), precision=self.precision)
|
image, upscale[1], seed, int(upscale[0]))
|
||||||
else:
|
else:
|
||||||
print(">> ESRGAN is disabled. Image not upscaled.")
|
print(">> ESRGAN is disabled. Image not upscaled.")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
Reference in New Issue
Block a user