feat(nodes): add realesrgan node

This commit is contained in:
psychedelicious 2023-07-16 01:06:50 +10:00
parent 32e7e52d69
commit 74ca87ac9e

View File

@ -1,48 +1,112 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path
from typing import Literal, Optional from typing import Literal, Union, cast
import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field from pydantic import Field
from realesrgan import RealESRGANer
from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin
from .baseinvocation import BaseInvocation, InvocationContext, InvocationConfig
from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput from .image import ImageOutput
# TODO: Populate this from disk?
# TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]
class UpscaleInvocation(BaseInvocation):
"""Upscales an image."""
# fmt: off class RealESRGANInvocation(BaseInvocation):
type: Literal["upscale"] = "upscale" """Upscales an image using RealESRGAN."""
# Inputs type: Literal["realesrgan"] = "realesrgan"
image: Optional[ImageField] = Field(description="The input image", default=None) image: Union[ImageField, None] = Field(default=None, description="The input image")
strength: float = Field(default=0.75, gt=0, le=1, description="The strength") model_name: REALESRGAN_MODELS = Field(
level: Literal[2, 4] = Field(default=2, description="The upscale level") default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
# fmt: on )
# Schema customisation
class Config(InvocationConfig):
schema_extra = {
"ui": {
"tags": ["upscaling", "image"],
},
}
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.services.images.get_pil_image(self.image.image_name) image = context.services.images.get_pil_image(self.image.image_name) # type: ignore
results = context.services.restoration.upscale_and_reconstruct( models_dir = cast(Path, context.services.configuration.root_dir) / Path("models/") # type: ignore
image_list=[[image, 0]],
upscale=(self.level, self.strength), rrdbnet_model = None
strength=0.0, # GFPGAN strength netscale = None
save_original=False, model_path = None
image_callback=None,
if self.model_name in [
"RealESRGAN_x4plus.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
]:
# x4 RRDBNet model
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=23,
num_grow_ch=32,
scale=4,
)
netscale = 4
elif self.model_name == "RealESRGAN_x4plus_anime_6B.pth":
# x4 RRDBNet model, 6 blocks
rrdbnet_model = RRDBNet(
num_in_ch=3,
num_out_ch=3,
num_feat=64,
num_block=6, # 6 blocks
num_grow_ch=32,
scale=4,
)
netscale = 4
# TODO: add x2 models handling?
# elif self.model_name in ["RealESRGAN_x2plus"]:
# # x2 RRDBNet model
# model = RRDBNet(
# num_in_ch=3,
# num_out_ch=3,
# num_feat=64,
# num_block=23,
# num_grow_ch=32,
# scale=2,
# )
# model_path = Path()
# netscale = 2
else:
msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg)
raise ValueError(msg)
model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_dir / model_path),
model=rrdbnet_model,
half=False,
) )
# Results are image and seed, unwrap for now # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: can this return multiple results? cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
# back to PIL
pil_image = Image.fromarray(
cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
).convert("RGBA")
image_dto = context.services.images.create( image_dto = context.services.images.create(
image=results[0][0], image=pil_image,
image_origin=ResourceOrigin.INTERNAL, image_origin=ResourceOrigin.INTERNAL,
image_category=ImageCategory.GENERAL, image_category=ImageCategory.GENERAL,
node_id=self.id, node_id=self.id,