mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): remove dependency on realesrgan
We used the `RealESRGANer` utility class from the repo. It handled model loading and tiled upscaling logic. Unfortunately, it hasn't been updated in over a year, had no types, and annoyingly printed to console. I've adapted the class, cleaning it up a bit and removing the bits that are not relevant for us. Upscaling functionality is identical.
This commit is contained in:
parent
84629df49c
commit
2192210910
@ -2,16 +2,16 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import cv2 as cv
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from PIL import Image
|
||||
from pydantic import ConfigDict
|
||||
from realesrgan import RealESRGANer
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.backend.image_util.esrgan import RealESRGANer
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
|
||||
@ -94,7 +94,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
|
||||
upsampler = RealESRGANer(
|
||||
scale=netscale,
|
||||
model_path=str(models_path / esrgan_model_path),
|
||||
model_path=models_path / esrgan_model_path,
|
||||
model=rrdbnet_model,
|
||||
half=False,
|
||||
tile=self.tile_size,
|
||||
@ -102,15 +102,15 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
|
||||
|
||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||
# TODO: This strips the alpha... is that okay?
|
||||
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
|
||||
|
||||
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
|
||||
# upscaling, you'll need to add a resize node after this one.
|
||||
upscaled_image, img_mode = upsampler.enhance(cv_image)
|
||||
upscaled_image = upsampler.enhance(cv2_image)
|
||||
|
||||
# back to PIL
|
||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
|
273
invokeai/backend/image_util/esrgan.py
Normal file
273
invokeai/backend/image_util/esrgan.py
Normal file
@ -0,0 +1,273 @@
|
||||
import math
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||
from cv2.typing import MatLike
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
|
||||
"""
|
||||
Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
|
||||
|
||||
The adaptation here has a few changes:
|
||||
- Remove print statements, use `tqdm` to show progress
|
||||
- Remove unused "outscale" logic, which simply scales the final image to a given factor
|
||||
- Remove `dni_weight` logic, which was only used when multiple models were used
|
||||
- Remove logic to fetch models from network
|
||||
- Add types, rename a few things
|
||||
"""
|
||||
|
||||
|
||||
class ImageMode(str, Enum):
|
||||
L = "L"
|
||||
RGB = "RGB"
|
||||
RGBA = "RGBA"
|
||||
|
||||
|
||||
class RealESRGANer:
|
||||
"""A helper class for upsampling images with RealESRGAN.
|
||||
|
||||
Args:
|
||||
scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
|
||||
model_path (str): The path to the pretrained model. It can be urls (will first download it automatically).
|
||||
model (nn.Module): The defined network. Default: None.
|
||||
tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop
|
||||
input images into tiles, and then process each of them. Finally, they will be merged into one image.
|
||||
0 denotes for do not use tile. Default: 0.
|
||||
tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
|
||||
pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
|
||||
half (float): Whether to use half precision during inference. Default: False.
|
||||
"""
|
||||
|
||||
output: torch.Tensor
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scale: int,
|
||||
model_path: Path,
|
||||
model: RRDBNet,
|
||||
tile: int = 0,
|
||||
tile_pad: int = 10,
|
||||
pre_pad: int = 10,
|
||||
half: bool = False,
|
||||
) -> None:
|
||||
self.scale = scale
|
||||
self.tile_size = tile
|
||||
self.tile_pad = tile_pad
|
||||
self.pre_pad = pre_pad
|
||||
self.mod_scale: Optional[int] = None
|
||||
self.half = half
|
||||
self.device = choose_torch_device()
|
||||
|
||||
loadnet = torch.load(model_path, map_location=torch.device("cpu"))
|
||||
|
||||
# prefer to use params_ema
|
||||
if "params_ema" in loadnet:
|
||||
keyname = "params_ema"
|
||||
else:
|
||||
keyname = "params"
|
||||
|
||||
model.load_state_dict(loadnet[keyname], strict=True)
|
||||
model.eval()
|
||||
self.model = model.to(self.device)
|
||||
|
||||
if self.half:
|
||||
self.model = self.model.half()
|
||||
|
||||
def pre_process(self, img: MatLike) -> None:
|
||||
"""Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
|
||||
img_tensor: torch.Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
|
||||
self.img = img_tensor.unsqueeze(0).to(self.device)
|
||||
if self.half:
|
||||
self.img = self.img.half()
|
||||
|
||||
# pre_pad
|
||||
if self.pre_pad != 0:
|
||||
self.img = torch.nn.functional.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
|
||||
# mod pad for divisible borders
|
||||
if self.scale == 2:
|
||||
self.mod_scale = 2
|
||||
elif self.scale == 1:
|
||||
self.mod_scale = 4
|
||||
if self.mod_scale is not None:
|
||||
self.mod_pad_h, self.mod_pad_w = 0, 0
|
||||
_, _, h, w = self.img.size()
|
||||
if h % self.mod_scale != 0:
|
||||
self.mod_pad_h = self.mod_scale - h % self.mod_scale
|
||||
if w % self.mod_scale != 0:
|
||||
self.mod_pad_w = self.mod_scale - w % self.mod_scale
|
||||
self.img = torch.nn.functional.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect")
|
||||
|
||||
def process(self) -> None:
|
||||
# model inference
|
||||
self.output = self.model(self.img)
|
||||
|
||||
def tile_process(self) -> None:
|
||||
"""It will first crop input images to tiles, and then process each tile.
|
||||
Finally, all the processed tiles are merged into one images.
|
||||
|
||||
Modified from: https://github.com/ata4/esrgan-launcher
|
||||
"""
|
||||
batch, channel, height, width = self.img.shape
|
||||
output_height = height * self.scale
|
||||
output_width = width * self.scale
|
||||
output_shape = (batch, channel, output_height, output_width)
|
||||
|
||||
# start with black image
|
||||
self.output = self.img.new_zeros(output_shape)
|
||||
tiles_x = math.ceil(width / self.tile_size)
|
||||
tiles_y = math.ceil(height / self.tile_size)
|
||||
|
||||
# loop over all tiles
|
||||
total_steps = tiles_y * tiles_x
|
||||
for i in tqdm(range(total_steps), desc="Upscaling"):
|
||||
y = i // tiles_x
|
||||
x = i % tiles_x
|
||||
# extract tile from input image
|
||||
ofs_x = x * self.tile_size
|
||||
ofs_y = y * self.tile_size
|
||||
# input tile area on total image
|
||||
input_start_x = ofs_x
|
||||
input_end_x = min(ofs_x + self.tile_size, width)
|
||||
input_start_y = ofs_y
|
||||
input_end_y = min(ofs_y + self.tile_size, height)
|
||||
|
||||
# input tile area on total image with padding
|
||||
input_start_x_pad = max(input_start_x - self.tile_pad, 0)
|
||||
input_end_x_pad = min(input_end_x + self.tile_pad, width)
|
||||
input_start_y_pad = max(input_start_y - self.tile_pad, 0)
|
||||
input_end_y_pad = min(input_end_y + self.tile_pad, height)
|
||||
|
||||
# input tile dimensions
|
||||
input_tile_width = input_end_x - input_start_x
|
||||
input_tile_height = input_end_y - input_start_y
|
||||
input_tile = self.img[
|
||||
:,
|
||||
:,
|
||||
input_start_y_pad:input_end_y_pad,
|
||||
input_start_x_pad:input_end_x_pad,
|
||||
]
|
||||
|
||||
# upscale tile
|
||||
with torch.no_grad():
|
||||
output_tile = self.model(input_tile)
|
||||
|
||||
# output tile area on total image
|
||||
output_start_x = input_start_x * self.scale
|
||||
output_end_x = input_end_x * self.scale
|
||||
output_start_y = input_start_y * self.scale
|
||||
output_end_y = input_end_y * self.scale
|
||||
|
||||
# output tile area without padding
|
||||
output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
|
||||
output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
|
||||
output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
|
||||
output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
|
||||
|
||||
# put tile into output image
|
||||
self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
|
||||
:,
|
||||
:,
|
||||
output_start_y_tile:output_end_y_tile,
|
||||
output_start_x_tile:output_end_x_tile,
|
||||
]
|
||||
|
||||
def post_process(self) -> torch.Tensor:
|
||||
# remove extra pad
|
||||
if self.mod_scale is not None:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[
|
||||
:,
|
||||
:,
|
||||
0 : h - self.mod_pad_h * self.scale,
|
||||
0 : w - self.mod_pad_w * self.scale,
|
||||
]
|
||||
# remove prepad
|
||||
if self.pre_pad != 0:
|
||||
_, _, h, w = self.output.size()
|
||||
self.output = self.output[
|
||||
:,
|
||||
:,
|
||||
0 : h - self.pre_pad * self.scale,
|
||||
0 : w - self.pre_pad * self.scale,
|
||||
]
|
||||
return self.output
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance(self, img: MatLike, esrgan_alpha_upscale: bool = True) -> npt.NDArray[Any]:
|
||||
np_img = img.astype(np.float32)
|
||||
alpha: Optional[np.ndarray] = None
|
||||
if np.max(np_img) > 256:
|
||||
# 16-bit image
|
||||
max_range = 65535
|
||||
else:
|
||||
max_range = 255
|
||||
np_img = np_img / max_range
|
||||
if len(np_img.shape) == 2:
|
||||
# grayscale image
|
||||
img_mode = ImageMode.L
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
|
||||
elif np_img.shape[2] == 4:
|
||||
# RGBA image with alpha channel
|
||||
img_mode = ImageMode.RGBA
|
||||
alpha = np_img[:, :, 3]
|
||||
np_img = np_img[:, :, 0:3]
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||
if esrgan_alpha_upscale:
|
||||
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
|
||||
else:
|
||||
img_mode = ImageMode.RGB
|
||||
np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# ------------------- process image (without the alpha channel) ------------------- #
|
||||
self.pre_process(np_img)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_tensor = self.post_process()
|
||||
output_img: npt.NDArray[Any] = output_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
|
||||
if img_mode is ImageMode.L:
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# ------------------- process the alpha channel if necessary ------------------- #
|
||||
if img_mode is ImageMode.RGBA:
|
||||
if esrgan_alpha_upscale:
|
||||
assert alpha is not None
|
||||
self.pre_process(alpha)
|
||||
if self.tile_size > 0:
|
||||
self.tile_process()
|
||||
else:
|
||||
self.process()
|
||||
output_alpha_tensor = self.post_process()
|
||||
output_alpha: npt.NDArray[Any] = output_alpha_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
|
||||
output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
|
||||
else: # use the cv2 resize for alpha channel
|
||||
assert alpha is not None
|
||||
h, w = alpha.shape[0:2]
|
||||
output_alpha = cv2.resize(
|
||||
alpha,
|
||||
(w * self.scale, h * self.scale),
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
# merge the alpha channel
|
||||
output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
|
||||
output_img[:, :, 3] = output_alpha
|
||||
|
||||
# ------------------------------ return ------------------------------ #
|
||||
if max_range == 65535: # 16-bit image
|
||||
output = (output_img * 65535.0).round().astype(np.uint16)
|
||||
else:
|
||||
output = (output_img * 255.0).round().astype(np.uint8)
|
||||
|
||||
return output
|
@ -71,7 +71,6 @@ dependencies = [
|
||||
"python-multipart",
|
||||
"python-socketio~=5.10.0",
|
||||
"pytorch-lightning",
|
||||
"realesrgan",
|
||||
"requests~=2.28.2",
|
||||
"rich~=13.3",
|
||||
"safetensors~=0.4.0",
|
||||
|
Loading…
x
Reference in New Issue
Block a user