* Converts ESRGAN image input to RGB

- Also adds typing for image input.
- Partially resolves #1604

* ensure there are unmasked pixels before color matching

Co-authored-by: Kyle Schouviller <kyle0654@hotmail.com>
This commit is contained in:
psychedelicious 2022-11-30 01:34:38 +11:00 committed by GitHub
parent 40c3ab0181
commit 61cc41aa3f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 11 deletions

View File

@ -141,6 +141,7 @@ class Generator():
np_init_rgb_pixels_masked = init_rgb_pixels[mask_pixels, :]
np_image_masked = np_image[mask_pixels, :]
if np_init_rgb_pixels_masked.size > 0:
init_means = np_init_rgb_pixels_masked.mean(axis=0)
init_std = np_init_rgb_pixels_masked.std(axis=0)
gen_means = np_image_masked.mean(axis=0)
@ -150,6 +151,8 @@ class Generator():
np_matched_result = np_image.copy()
np_matched_result[:,:,:] = (((np_matched_result[:,:,:].astype(np.float32) - gen_means[None,None,:]) / gen_std[None,None,:]) * init_std[None,None,:] + init_means[None,None,:]).clip(0, 255).astype(np.uint8)
matched_result = Image.fromarray(np_matched_result, mode='RGB')
else:
matched_result = Image.fromarray(np_image, mode='RGB')
# Blur the mask out (into init image) by specified amount
if mask_blur_radius > 0:

View File

@ -5,7 +5,7 @@ import os
from ldm.invoke.globals import Globals
from PIL import Image
from PIL.Image import Image as ImageType
class ESRGAN():
def __init__(self, bg_tile_size=400) -> None:
@ -41,7 +41,7 @@ class ESRGAN():
return bg_upsampler
def process(self, image, strength: float, seed: str = None, upsampler_scale: int = 2):
def process(self, image: ImageType, strength: float, seed: str = None, upsampler_scale: int = 2):
with warnings.catch_warnings():
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=UserWarning)
@ -62,6 +62,8 @@ class ESRGAN():
print(
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
)
# ESRGAN outputs images with partial transparency if given RGBA images; convert to RGB
image = image.convert("RGB")
# REALSRGAN expects a BGR np array; make array and flip channels
bgr_image_array = np.array(image, dtype=np.uint8)[...,::-1]