feat(nodes): add RealESRGAN_x2plus.pth, update upscale nodes

- add `RealESRGAN_x2plus.pth` model to installer
- add `RealESRGAN_x2plus.pth` to `realesrgan` node
- rename `RealESRGAN` to `ESRGAN` in nodes
- make `scale_factor` optional in `img_scale` node
This commit is contained in:
psychedelicious 2023-07-17 21:00:22 +10:00
parent 99383c2701
commit 56098f370c
3 changed files with 26 additions and 22 deletions

View File

@ -437,8 +437,8 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
type: Literal["img_scale"] = "img_scale" type: Literal["img_scale"] = "img_scale"
# Inputs # Inputs
image: Optional[ImageField] = Field(default=None, description="The image to scale") image: Optional[ImageField] = Field(default=None, description="The image to scale")
scale_factor: float = Field(gt=0, description="The factor by which to scale the image") scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode") resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
# fmt: on # fmt: on

View File

@ -1,6 +1,6 @@
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath from pathlib import Path
from typing import Literal, Union, cast from typing import Literal, Union
import cv2 as cv import cv2 as cv
import numpy as np import numpy as np
@ -16,19 +16,20 @@ from .image import ImageOutput
# TODO: Populate this from disk? # TODO: Populate this from disk?
# TODO: Use model manager to load? # TODO: Use model manager to load?
REALESRGAN_MODELS = Literal[ ESRGAN_MODELS = Literal[
"RealESRGAN_x4plus.pth", "RealESRGAN_x4plus.pth",
"RealESRGAN_x4plus_anime_6B.pth", "RealESRGAN_x4plus_anime_6B.pth",
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
"RealESRGAN_x2plus.pth",
] ]
class RealESRGANInvocation(BaseInvocation): class ESRGANInvocation(BaseInvocation):
"""Upscales an image using RealESRGAN.""" """Upscales an image using RealESRGAN."""
type: Literal["realesrgan"] = "realesrgan" type: Literal["esrgan"] = "esrgan"
image: Union[ImageField, None] = Field(default=None, description="The input image") image: Union[ImageField, None] = Field(default=None, description="The input image")
model_name: REALESRGAN_MODELS = Field( model_name: ESRGAN_MODELS = Field(
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use" default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
) )
@ -65,19 +66,17 @@ class RealESRGANInvocation(BaseInvocation):
scale=4, scale=4,
) )
netscale = 4 netscale = 4
# TODO: add x2 models handling? elif self.model_name in ["RealESRGAN_x2plus.pth"]:
# elif self.model_name in ["RealESRGAN_x2plus"]: # x2 RRDBNet model
# # x2 RRDBNet model rrdbnet_model = RRDBNet(
# model = RRDBNet( num_in_ch=3,
# num_in_ch=3, num_out_ch=3,
# num_out_ch=3, num_feat=64,
# num_feat=64, num_block=23,
# num_block=23, num_grow_ch=32,
# num_grow_ch=32, scale=2,
# scale=2, )
# ) netscale = 2
# model_path = Path()
# netscale = 2
else: else:
msg = f"Invalid RealESRGAN model: {self.model_name}" msg = f"Invalid RealESRGAN model: {self.model_name}"
context.services.logger.error(msg) context.services.logger.error(msg)

View File

@ -223,7 +223,7 @@ def download_conversion_models():
# --------------------------------------------- # ---------------------------------------------
def download_realesrgan(): def download_realesrgan():
logger.info("Installing RealESRGAN models...") logger.info("Installing ESRGAN Upscaling models...")
URLs = [ URLs = [
dict( dict(
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
@ -240,6 +240,11 @@ def download_realesrgan():
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth", dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
description = "ESRGAN_SRx4_DF2KOST_official.pth", description = "ESRGAN_SRx4_DF2KOST_official.pth",
), ),
dict(
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
description = "RealESRGAN_x2plus.pth",
),
] ]
for model in URLs: for model in URLs:
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description']) download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])