diff --git a/invokeai/app/invocations/upscale.py b/invokeai/app/invocations/upscale.py
index 15497343ae..f320e010e2 100644
--- a/invokeai/app/invocations/upscale.py
+++ b/invokeai/app/invocations/upscale.py
@@ -1,48 +1,112 @@
-# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654)
-
-from typing import Literal, Optional
+# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
+from pathlib import Path
+from typing import Literal, Union, cast
 
+import cv2 as cv
+import numpy as np
+from basicsr.archs.rrdbnet_arch import RRDBNet
+from PIL import Image
 from pydantic import Field
+from realesrgan import RealESRGANer
 
 from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
-from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
+
+from .baseinvocation import BaseInvocation, InvocationContext
 from .image import ImageOutput
 
+# TODO: Populate this from disk?
+# TODO: Use model manager to load?
+REALESRGAN_MODELS = Literal[
+    "RealESRGAN_x4plus.pth",
+    "RealESRGAN_x4plus_anime_6B.pth",
+    "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+]
+
 
-class UpscaleInvocation(BaseInvocation):
-    """Upscales an image."""
-
-    # fmt: off
-    type: Literal["upscale"] = "upscale"
+class RealESRGANInvocation(BaseInvocation):
+    """Upscales an image using RealESRGAN."""
 
-    # Inputs
-    image: Optional[ImageField] = Field(description="The input image", default=None)
-    strength: float = Field(default=0.75, gt=0, le=1, description="The strength")
-    level: Literal[2, 4] = Field(default=2, description="The upscale level")
-    # fmt: on
-
-    # Schema customisation
-    class Config(InvocationConfig):
-        schema_extra = {
-            "ui": {
-                "tags": ["upscaling", "image"],
-            },
-        }
+    type: Literal["realesrgan"] = "realesrgan"
+    image: Union[ImageField, None] = Field(default=None, description="The input image")
+    model_name: REALESRGAN_MODELS = Field(
+        default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
+    )
 
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        image = context.services.images.get_pil_image(self.image.image_name)
-        results = context.services.restoration.upscale_and_reconstruct(
-            image_list=[[image, 0]],
-            upscale=(self.level, self.strength),
-            strength=0.0,  # GFPGAN strength
-            save_original=False,
-            image_callback=None,
+        image = context.services.images.get_pil_image(self.image.image_name)  # type: ignore
+        models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/")  # type: ignore
+
+        rrdbnet_model = None
+        netscale = None
+        model_path = None
+
+        if self.model_name in [
+            "RealESRGAN_x4plus.pth",
+            "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
+        ]:
+            # x4 RRDBNet model
+            rrdbnet_model = RRDBNet(
+                num_in_ch=3,
+                num_out_ch=3,
+                num_feat=64,
+                num_block=23,
+                num_grow_ch=32,
+                scale=4,
+            )
+            netscale = 4
+        elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
+            # x4 RRDBNet model, 6 blocks
+            rrdbnet_model = RRDBNet(
+                num_in_ch=3,
+                num_out_ch=3,
+                num_feat=64,
+                num_block=6,  # 6 blocks
+                num_grow_ch=32,
+                scale=4,
+            )
+            netscale = 4
+        # TODO: add x2 models handling?
+        # elif self.model_name in ["RealESRGAN_x2plus"]:
+        #     # x2 RRDBNet model
+        #     model = RRDBNet(
+        #         num_in_ch=3,
+        #         num_out_ch=3,
+        #         num_feat=64,
+        #         num_block=23,
+        #         num_grow_ch=32,
+        #         scale=2,
+        #     )
+        #     model_path = Path()
+        #     netscale = 2
+        else:
+            msg = f"Invalid RealESRGAN model: {self.model_name}"
+            context.services.logger.error(msg)
+            raise ValueError(msg)
+
+        model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
+
+        upsampler = RealESRGANer(
+            scale=netscale,
+            model_path=str(models_dir / model_path),
+            model=rrdbnet_model,
+            half=False,
         )
 
-        # Results are image and seed, unwrap for now
-        # TODO: can this return multiple results?
+        # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
+        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
+
+        # We can pass an `outscale` value here, but it just resizes the image by that factor after
+        # upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
+        # upscaling, you'll need to add a resize node after this one.
+        upscaled_image, img_mode = upsampler.enhance(cv_image)
+
+        # back to PIL
+        pil_image = Image.fromarray(
+            cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
+        ).convert("RGBA")
+
         image_dto = context.services.images.create(
-            image=results[0][0],
+            image=pil_image,
             image_origin=ResourceOrigin.INTERNAL,
             image_category=ImageCategory.GENERAL,
             node_id=self.id,
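For anyone who wants to poke at the new flow outside the node graph, here is a minimal standalone sketch of the same Real-ESRGAN round trip. Everything in it is illustrative rather than InvokeAI's actual layout: it assumes `realesrgan`, `basicsr`, `opencv-python`, and `Pillow` are installed, that `RealESRGAN_x4plus.pth` has already been downloaded into the working directory, and that `input.png`/`output.png` are placeholder filenames.

```python
# Standalone sketch of the RealESRGANInvocation flow (assumptions noted above).
import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer

# x4 RRDBNet matching RealESRGAN_x4plus.pth (23 blocks), same config as the node
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
upsampler = RealESRGANer(scale=4, model_path="RealESRGAN_x4plus.pth", model=model, half=False)

# PIL works in RGB while Real-ESRGAN expects cv2-style BGR arrays, so convert on the way in...
pil_in = Image.open("input.png").convert("RGB")  # hypothetical input file
bgr = cv.cvtColor(np.array(pil_in), cv.COLOR_RGB2BGR)

upscaled, _img_mode = upsampler.enhance(bgr)  # fixed 4x upscale for this model

# ...and back to RGB on the way out
Image.fromarray(cv.cvtColor(upscaled, cv.COLOR_BGR2RGB)).save("output.png")
```

As the inline comment in the node notes, `enhance()` does accept an `outscale` argument, but it only resizes the result after the model's fixed 4x pass, so chaining a resize node after this one remains the cleaner way to reach other target sizes.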