mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
114 lines
3.4 KiB
Python
114 lines
3.4 KiB
Python
import os
|
|
import warnings
|
|
|
|
import numpy as np
|
|
import torch
|
|
from PIL import Image
|
|
from PIL.Image import Image as ImageType
|
|
|
|
from invokeai.backend.globals import Globals
|
|
|
|
|
|
class ESRGAN:
    """Real-ESRGAN image upscaler.

    Builds a ``realesrgan.RealESRGANer`` around an ``SRVGGNetCompact`` network
    whose weights are a DNI blend of ``realesr-general-x4v3.pth`` and
    ``realesr-general-wdn-x4v3.pth``, both read from
    ``<Globals.root>/models/realesrgan``.
    """

    def __init__(self, bg_tile_size=400) -> None:
        """Create an upscaler.

        Args:
            bg_tile_size: Tile size forwarded to ``RealESRGANer`` as ``tile``.
        """
        self.bg_tile_size = bg_tile_size

    def load_esrgan_bg_upsampler(self, denoise_str):
        """Construct a 4x ``RealESRGANer`` instance.

        Args:
            denoise_str: DNI weight for the general x4v3 weights; the wdn
                (denoise) weights receive ``1 - denoise_str``.

        Returns:
            The configured ``RealESRGANer``.
        """
        # fp16 only when CUDA is present; CPU and MPS (e.g. Apple M1) use fp32.
        use_half_precision = torch.cuda.is_available()

        # Imported here rather than at module level, presumably so realesrgan
        # is only required when upscaling is actually requested.
        from realesrgan import RealESRGANer
        from realesrgan.archs.srvgg_arch import SRVGGNetCompact

        model = SRVGGNetCompact(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_conv=32,
            upscale=4,
            act_type="prelu",
        )
        model_path = os.path.join(
            Globals.root, "models/realesrgan/realesr-general-x4v3.pth"
        )
        wdn_model_path = os.path.join(
            Globals.root, "models/realesrgan/realesr-general-wdn-x4v3.pth"
        )
        scale = 4

        bg_upsampler = RealESRGANer(
            scale=scale,
            model_path=[model_path, wdn_model_path],
            model=model,
            tile=self.bg_tile_size,
            dni_weight=[denoise_str, 1 - denoise_str],
            tile_pad=10,
            pre_pad=0,
            half=use_half_precision,
        )
        return bg_upsampler

    def process(
        self,
        image: ImageType,
        strength: float,
        seed: str = None,
        upsampler_scale: int = 2,
        denoise_str: float = 0.75,
    ):
        """Upscale ``image`` with Real-ESRGAN, optionally blending with the input.

        Args:
            image: Source image; converted to RGB before upscaling.
            strength: ``Image.blend`` factor between the (resized) input and
                the upscaled result. Values below 1.0 mix in the input;
                1.0 returns the upscaled image unblended.
            seed: Only used in the progress message, which is printed only
                when ``seed`` is not None.
            upsampler_scale: Passed to ``RealESRGANer.enhance`` as
                ``outscale``. 0 is rejected and the input is returned as-is.
            denoise_str: DNI weight forwarded to ``load_esrgan_bg_upsampler``.

        Returns:
            The upscaled RGB image. The original ``image`` is returned
            unchanged when ``upsampler_scale`` is 0 or the model fails to load.
        """
        # Reject an invalid scale before paying for a model load.
        if upsampler_scale == 0:
            print(">> Real-ESRGAN: Invalid scaling option. Image not upscaled.")
            return image

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

            try:
                upsampler = self.load_esrgan_bg_upsampler(denoise_str)
            except Exception:
                import sys
                import traceback

                print(">> Error loading Real-ESRGAN:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                # Without a model there is nothing to upscale with; hand back
                # the input rather than failing on an unbound ``upsampler``.
                return image

        if seed is not None:
            print(
                f">> Real-ESRGAN Upscaling seed:{seed}, scale:{upsampler_scale}x, tile:{self.bg_tile_size}, denoise:{denoise_str}"
            )
        # ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
        image = image.convert("RGB")

        # REALSRGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]

        output, _ = upsampler.enhance(
            bgr_image_array,
            outscale=upsampler_scale,
            alpha_upsampler="realesrgan",
        )
        # Drop our reference to the upsampler (and the model it wraps) before
        # torch.cuda.empty_cache() below: empty_cache() only releases
        # unoccupied cached blocks, so still-referenced weights would stay put.
        del upsampler

        # Flip the channels back to RGB
        res = Image.fromarray(output[..., ::-1])

        if strength < 1.0:
            # Image.blend needs both images at the same size; compare PIL
            # (w, h) sizes, not the numpy array's element count.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return res
|