diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index 557325eaa0..0f699b2d15 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -2,16 +2,16 @@
 from pathlib import Path
 from typing import Literal
 
-import cv2 as cv
+import cv2
 import numpy as np
 import torch
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from PIL import Image
 from pydantic import ConfigDict
-from realesrgan import RealESRGANer
 
 from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
+from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN
 from invokeai.backend.util.devices import choose_torch_device
 
 from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
@@ -92,9 +92,9 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
         esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
 
-        upsampler = RealESRGANer(
+        upscaler = RealESRGAN(
             scale=netscale,
-            model_path=str(models_path / esrgan_model_path),
+            model_path=models_path / esrgan_model_path,
             model=rrdbnet_model,
             half=False,
             tile=self.tile_size,
@@ -102,15 +102,9 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
 
         # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
         # TODO: This strips the alpha... is that okay?
-        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
-
-        # We can pass an `outscale` value here, but it just resizes the image by that factor after
-        # upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
-        # upscaling, you'll need to add a resize node after this one.
-        upscaled_image, img_mode = upsampler.enhance(cv_image)
-
-        # back to PIL
-        pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
+        cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
+        upscaled_image = upscaler.upscale(cv2_image)
+        pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
 
         torch.cuda.empty_cache()
         if choose_torch_device() == torch.device("mps"):
diff --git a/invokeai/backend/image_util/realesrgan/LICENSE b/invokeai/backend/image_util/realesrgan/LICENSE
new file mode 100644
index 0000000000..552a1eeaf0
--- /dev/null
+++ b/invokeai/backend/image_util/realesrgan/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2021, Xintao Wang
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright notice, this
+   list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright notice,
+   this list of conditions and the following disclaimer in the documentation
+   and/or other materials provided with the distribution.
+
+3. Neither the name of the copyright holder nor the names of its
+   contributors may be used to endorse or promote products derived from
+   this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/invokeai/backend/image_util/realesrgan/__init__.py b/invokeai/backend/image_util/realesrgan/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/invokeai/backend/image_util/realesrgan/realesrgan.py b/invokeai/backend/image_util/realesrgan/realesrgan.py
new file mode 100644
index 0000000000..4d41dabc1e
--- /dev/null
+++ b/invokeai/backend/image_util/realesrgan/realesrgan.py
@@ -0,0 +1,274 @@
+import math
+from enum import Enum
+from pathlib import Path
+from typing import Any, Optional
+
+import cv2
+import numpy as np
+import numpy.typing as npt
+import torch
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from cv2.typing import MatLike
+from tqdm import tqdm
+
+from invokeai.backend.util.devices import choose_torch_device
+
+"""
+Adapted from https://github.com/xinntao/Real-ESRGAN/blob/master/realesrgan/utils.py
+License is BSD3, copied to `LICENSE` in this directory.
+
+The adaptation here has a few changes:
+- Remove print statements, use `tqdm` to show progress
+- Remove unused "outscale" logic, which simply scales the final image to a given factor
+- Remove `dni_weight` logic, which was only used when multiple models were used
+- Remove logic to fetch models from network
+- Add types, rename a few things
+"""
+
+
+class ImageMode(str, Enum):
+    L = "L"
+    RGB = "RGB"
+    RGBA = "RGBA"
+
+
+class RealESRGAN:
+    """A helper class for upsampling images with RealESRGAN.
+
+    Args:
+        scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4.
+        model_path (Path): The path to the pretrained model weights.
+        model (RRDBNet): The defined network.
+        tile (int): Because overly large images can exhaust GPU memory, this option first crops the input
+            image into tiles, processes each of them, and finally merges them back into one image.
+            0 means tiling is disabled. Default: 0.
+        tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10.
+        pre_pad (int): Pad the input images to avoid border artifacts. Default: 10.
+        half (bool): Whether to use half precision during inference. Default: False.
+    """
+
+    output: torch.Tensor
+
+    def __init__(
+        self,
+        scale: int,
+        model_path: Path,
+        model: RRDBNet,
+        tile: int = 0,
+        tile_pad: int = 10,
+        pre_pad: int = 10,
+        half: bool = False,
+    ) -> None:
+        self.scale = scale
+        self.tile_size = tile
+        self.tile_pad = tile_pad
+        self.pre_pad = pre_pad
+        self.mod_scale: Optional[int] = None
+        self.half = half
+        self.device = choose_torch_device()
+
+        loadnet = torch.load(model_path, map_location=torch.device("cpu"))
+
+        # prefer to use params_ema
+        if "params_ema" in loadnet:
+            keyname = "params_ema"
+        else:
+            keyname = "params"
+
+        model.load_state_dict(loadnet[keyname], strict=True)
+        model.eval()
+        self.model = model.to(self.device)
+
+        if self.half:
+            self.model = self.model.half()
+
+    def pre_process(self, img: MatLike) -> None:
+        """Pre-process, such as pre-pad and mod pad, so that the images can be divisible"""
+        img_tensor: torch.Tensor = torch.from_numpy(np.transpose(img, (2, 0, 1))).float()
+        self.img = img_tensor.unsqueeze(0).to(self.device)
+        if self.half:
+            self.img = self.img.half()
+
+        # pre_pad
+        if self.pre_pad != 0:
+            self.img = torch.nn.functional.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), "reflect")
+        # mod pad for divisible borders
+        if self.scale == 2:
+            self.mod_scale = 2
+        elif self.scale == 1:
+            self.mod_scale = 4
+        if self.mod_scale is not None:
+            self.mod_pad_h, self.mod_pad_w = 0, 0
+            _, _, h, w = self.img.size()
+            if h % self.mod_scale != 0:
+                self.mod_pad_h = self.mod_scale - h % self.mod_scale
+            if w % self.mod_scale != 0:
+                self.mod_pad_w = self.mod_scale - w % self.mod_scale
+            self.img = torch.nn.functional.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), "reflect")
+
+    def process(self) -> None:
+        # model inference
+        self.output = self.model(self.img)
+
+    def tile_process(self) -> None:
+        """It will first crop input images to tiles, and then process each tile.
+        Finally, all the processed tiles are merged into one image.
+
+        Modified from: https://github.com/ata4/esrgan-launcher
+        """
+        batch, channel, height, width = self.img.shape
+        output_height = height * self.scale
+        output_width = width * self.scale
+        output_shape = (batch, channel, output_height, output_width)
+
+        # start with black image
+        self.output = self.img.new_zeros(output_shape)
+        tiles_x = math.ceil(width / self.tile_size)
+        tiles_y = math.ceil(height / self.tile_size)
+
+        # loop over all tiles
+        total_steps = tiles_y * tiles_x
+        for i in tqdm(range(total_steps), desc="Upscaling"):
+            y = i // tiles_x
+            x = i % tiles_x
+            # extract tile from input image
+            ofs_x = x * self.tile_size
+            ofs_y = y * self.tile_size
+            # input tile area on total image
+            input_start_x = ofs_x
+            input_end_x = min(ofs_x + self.tile_size, width)
+            input_start_y = ofs_y
+            input_end_y = min(ofs_y + self.tile_size, height)
+
+            # input tile area on total image with padding
+            input_start_x_pad = max(input_start_x - self.tile_pad, 0)
+            input_end_x_pad = min(input_end_x + self.tile_pad, width)
+            input_start_y_pad = max(input_start_y - self.tile_pad, 0)
+            input_end_y_pad = min(input_end_y + self.tile_pad, height)
+
+            # input tile dimensions
+            input_tile_width = input_end_x - input_start_x
+            input_tile_height = input_end_y - input_start_y
+            input_tile = self.img[
+                :,
+                :,
+                input_start_y_pad:input_end_y_pad,
+                input_start_x_pad:input_end_x_pad,
+            ]
+
+            # upscale tile
+            with torch.no_grad():
+                output_tile = self.model(input_tile)
+
+            # output tile area on total image
+            output_start_x = input_start_x * self.scale
+            output_end_x = input_end_x * self.scale
+            output_start_y = input_start_y * self.scale
+            output_end_y = input_end_y * self.scale
+
+            # output tile area without padding
+            output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale
+            output_end_x_tile = output_start_x_tile + input_tile_width * self.scale
+            output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale
+            output_end_y_tile = output_start_y_tile + input_tile_height * self.scale
+
+            # put tile into output image
+            self.output[:, :, output_start_y:output_end_y, output_start_x:output_end_x] = output_tile[
+                :,
+                :,
+                output_start_y_tile:output_end_y_tile,
+                output_start_x_tile:output_end_x_tile,
+            ]
+
+    def post_process(self) -> torch.Tensor:
+        # remove extra pad
+        if self.mod_scale is not None:
+            _, _, h, w = self.output.size()
+            self.output = self.output[
+                :,
+                :,
+                0 : h - self.mod_pad_h * self.scale,
+                0 : w - self.mod_pad_w * self.scale,
+            ]
+        # remove prepad
+        if self.pre_pad != 0:
+            _, _, h, w = self.output.size()
+            self.output = self.output[
+                :,
+                :,
+                0 : h - self.pre_pad * self.scale,
+                0 : w - self.pre_pad * self.scale,
+            ]
+        return self.output
+
+    @torch.no_grad()
+    def upscale(self, img: MatLike, esrgan_alpha_upscale: bool = True) -> npt.NDArray[Any]:
+        np_img = img.astype(np.float32)
+        alpha: Optional[np.ndarray] = None
+        if np.max(np_img) > 256:
+            # 16-bit image
+            max_range = 65535
+        else:
+            max_range = 255
+        np_img = np_img / max_range
+        if len(np_img.shape) == 2:
+            # grayscale image
+            img_mode = ImageMode.L
+            np_img = cv2.cvtColor(np_img, cv2.COLOR_GRAY2RGB)
+        elif np_img.shape[2] == 4:
+            # RGBA image with alpha channel
+            img_mode = ImageMode.RGBA
+            alpha = np_img[:, :, 3]
+            np_img = np_img[:, :, 0:3]
+            np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+            if esrgan_alpha_upscale:
+                alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB)
+        else:
+            img_mode = ImageMode.RGB
+            np_img = cv2.cvtColor(np_img, cv2.COLOR_BGR2RGB)
+
+        # ------------------- process image (without the alpha channel) ------------------- #
+        self.pre_process(np_img)
+        if self.tile_size > 0:
+            self.tile_process()
+        else:
+            self.process()
+        output_tensor = self.post_process()
+        output_img: npt.NDArray[Any] = output_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+        output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0))
+        if img_mode is ImageMode.L:
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY)
+
+        # ------------------- process the alpha channel if necessary ------------------- #
+        if img_mode is ImageMode.RGBA:
+            if esrgan_alpha_upscale:
+                assert alpha is not None
+                self.pre_process(alpha)
+                if self.tile_size > 0:
+                    self.tile_process()
+                else:
+                    self.process()
+                output_alpha_tensor = self.post_process()
+                output_alpha: npt.NDArray[Any] = output_alpha_tensor.data.squeeze().float().cpu().clamp_(0, 1).numpy()
+                output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0))
+                output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY)
+            else:  # use the cv2 resize for alpha channel
+                assert alpha is not None
+                h, w = alpha.shape[0:2]
+                output_alpha = cv2.resize(
+                    alpha,
+                    (w * self.scale, h * self.scale),
+                    interpolation=cv2.INTER_LINEAR,
+                )
+
+            # merge the alpha channel
+            output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA)
+            output_img[:, :, 3] = output_alpha
+
+        # ------------------------------ return ------------------------------ #
+        if max_range == 65535:  # 16-bit image
+            output = (output_img * 65535.0).round().astype(np.uint16)
+        else:
+            output = (output_img * 255.0).round().astype(np.uint8)
+
+        return output
diff --git a/pyproject.toml b/pyproject.toml
index 80e5413530..b80b4658dc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ classifiers = [
 dependencies = [
     "accelerate~=0.24.0",
     "albumentations",
+    "basicsr",
     "click",
     "clip_anytorch",  # replacing "clip @ https://github.com/openai/CLIP/archive/eaa22acb90a5876642d0507623e859909230a52d.zip",
     "compel~=2.0.2",
@@ -71,7 +72,6 @@ dependencies = [
     "python-multipart",
     "python-socketio~=5.10.0",
     "pytorch-lightning",
-    "realesrgan",
     "requests~=2.28.2",
     "rich~=13.3",
     "safetensors~=0.4.0",
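
A minimal usage sketch of the vendored RealESRGAN class follows, mirroring how ESRGANInvocation drives it in the upscale.py hunk above. It is illustrative only and not part of the patch: the RRDBNet hyperparameters, weights path, tile size, and file names are assumptions for a 4x (x4plus-style) model; substitute whichever model is actually installed.

# Illustrative sketch only: RRDBNet config and paths below are assumptions for a 4x model.
from pathlib import Path

import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image

from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN

# RRDBNet hyperparameters assumed for the RealESRGAN_x4plus family.
rrdbnet = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)

upscaler = RealESRGAN(
    scale=4,
    model_path=Path("models/core/upscaling/realesrgan/RealESRGAN_x4plus.pth"),  # assumed location
    model=rrdbnet,
    half=False,
    tile=400,  # tile to bound GPU memory; 0 disables tiling
)

# The class works on BGR numpy arrays (cv2 convention), so convert from PIL's RGB first.
pil_in = Image.open("input.png").convert("RGB")
bgr_in = cv2.cvtColor(np.array(pil_in), cv2.COLOR_RGB2BGR)

bgr_out = upscaler.upscale(bgr_in)  # BGR uint8 (or uint16) array, 4x the input size

# Back to PIL, as the invocation does.
pil_out = Image.fromarray(cv2.cvtColor(bgr_out, cv2.COLOR_BGR2RGB)).convert("RGBA")
pil_out.save("output.png")

Note that upscale() always returns the network's native scale factor; any other output size needs a separate resize step, as the removed outscale comment in upscale.py pointed out.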