# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team from pathlib import Path from typing import Literal import cv2 import numpy as np import torch from basicsr.archs.rrdbnet_arch import RRDBNet from PIL import Image from pydantic import ConfigDict from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.backend.image_util.realesrgan.realesrgan import RealESRGAN from invokeai.backend.util.devices import choose_torch_device from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, invocation # TODO: Populate this from disk? # TODO: Use model manager to load? ESRGAN_MODELS = Literal[ "RealESRGAN_x4plus.pth", "RealESRGAN_x4plus_anime_6B.pth", "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", "RealESRGAN_x2plus.pth", ] if choose_torch_device() == torch.device("mps"): from torch import mps @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.3.0") class ESRGANInvocation(BaseInvocation, WithMetadata): """Upscales an image using RealESRGAN.""" image: ImageField = InputField(description="The input image") model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use") tile_size: int = InputField( default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)" ) model_config = ConfigDict(protected_namespaces=()) def invoke(self, context: InvocationContext) -> ImageOutput: image = context.services.images.get_pil_image(self.image.image_name) models_path = context.services.configuration.models_path rrdbnet_model = None netscale = None esrgan_model_path = None if self.model_name in [ "RealESRGAN_x4plus.pth", "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", ]: # x4 RRDBNet model rrdbnet_model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4, ) netscale = 4 elif self.model_name in ["RealESRGAN_x4plus_anime_6B.pth"]: # x4 RRDBNet model, 6 blocks rrdbnet_model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, # 6 blocks num_grow_ch=32, scale=4, ) netscale = 4 elif self.model_name in ["RealESRGAN_x2plus.pth"]: # x2 RRDBNet model rrdbnet_model = RRDBNet( num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2, ) netscale = 2 else: msg = f"Invalid RealESRGAN model: {self.model_name}" context.services.logger.error(msg) raise ValueError(msg) esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}") upscaler = RealESRGAN( scale=netscale, model_path=models_path / esrgan_model_path, model=rrdbnet_model, half=False, tile=self.tile_size, ) # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL # TODO: This strips the alpha... is that okay? cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR) upscaled_image = upscaler.upscale(cv2_image) pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA") torch.cuda.empty_cache() if choose_torch_device() == torch.device("mps"): mps.empty_cache() image_dto = context.services.images.create( image=pil_image, image_origin=ResourceOrigin.INTERNAL, image_category=ImageCategory.GENERAL, node_id=self.id, session_id=context.graph_execution_state_id, is_intermediate=self.is_intermediate, metadata=self.metadata, workflow=context.workflow, ) return ImageOutput( image=ImageField(image_name=image_dto.image_name), width=image_dto.width, height=image_dto.height, )