# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team

from pathlib import Path
from typing import Literal, Union, cast

import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer

from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin

from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput

# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
    "RealESRGAN_x4plus.pth",
    "RealESRGAN_x4plus_anime_6B.pth",
    "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]

# RRDBNet architecture parameters for each supported checkpoint, as
# (num_block, scale). All current checkpoints are x4 RRDBNets; the anime
# variant is the lighter 6-block network. The remaining RRDBNet settings
# (3 in/out channels, 64 features, 32 growth channels) are shared.
# TODO: add x2 models handling? e.g. "RealESRGAN_x2plus" -> (23, 2)
_REALESRGAN_ARCH_PARAMS = {
    "RealESRGAN_x4plus.pth": (23, 4),
    "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": (23, 4),
    "RealESRGAN_x4plus_anime_6B.pth": (6, 4),
}


class RealESRGANInvocation(BaseInvocation):
    """Upscales an image using RealESRGAN."""

    type: Literal["realesrgan"] = "realesrgan"
    image: Union[ImageField, None] = Field(default=None, description="The input image")
    model_name: REALESRGAN_MODELS = Field(
        default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Upscale ``self.image`` with the selected Real-ESRGAN checkpoint.

        The checkpoint is loaded from
        ``<root_dir>/models/core/upscaling/realesrgan/<model_name>``.

        Returns:
            An ``ImageOutput`` referencing the newly stored (RGBA) upscaled image.

        Raises:
            ValueError: if no input image was supplied, or if ``model_name`` is
                not one of the supported checkpoints.
        """
        # `image` is optional at the schema level (default None); fail with a
        # clear, logged error instead of an AttributeError on `.image_name`.
        if self.image is None:
            msg = "RealESRGAN requires an input image"
            context.services.logger.error(msg)
            raise ValueError(msg)

        image = context.services.images.get_pil_image(self.image.image_name)
        models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/")  # type: ignore

        arch_params = _REALESRGAN_ARCH_PARAMS.get(self.model_name)
        if arch_params is None:
            msg = f"Invalid RealESRGAN model: {self.model_name}"
            context.services.logger.error(msg)
            raise ValueError(msg)
        num_block, netscale = arch_params

        rrdbnet_model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=netscale,
        )

        model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")

        upsampler = RealESRGANer(
            scale=netscale,
            model_path=str(models_dir / model_path),
            model=rrdbnet_model,
            half=False,
        )

        # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

        # We can pass an `outscale` value here, but it just resizes the image by that factor after
        # upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
        # upscaling, you'll need to add a resize node after this one.
        # The second return value (the detected image mode) is not needed.
        upscaled_image, _ = upsampler.enhance(cv_image)

        # back to PIL
        pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")

        image_dto = context.services.images.create(
            image=pil_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )