mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): add RealESRGAN_x2plus.pth
, update upscale nodes
- add `RealESRGAN_x2plus.pth` model to installer - add `RealESRGAN_x2plus.pth` to `realesrgan` node - rename `RealESRGAN` to `ESRGAN` in nodes - make `scale_factor` optional in `img_scale` node
This commit is contained in:
parent
99383c2701
commit
56098f370c
@ -437,8 +437,8 @@ class ImageScaleInvocation(BaseInvocation, PILInvocationConfig):
|
|||||||
type: Literal["img_scale"] = "img_scale"
|
type: Literal["img_scale"] = "img_scale"
|
||||||
|
|
||||||
# Inputs
|
# Inputs
|
||||||
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
image: Optional[ImageField] = Field(default=None, description="The image to scale")
|
||||||
scale_factor: float = Field(gt=0, description="The factor by which to scale the image")
|
scale_factor: Optional[float] = Field(default=2.0, gt=0, description="The factor by which to scale the image")
|
||||||
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
resample_mode: PIL_RESAMPLING_MODES = Field(default="bicubic", description="The resampling mode")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
|
||||||
from pathlib import Path, PosixPath
|
from pathlib import Path
|
||||||
from typing import Literal, Union, cast
|
from typing import Literal, Union
|
||||||
|
|
||||||
import cv2 as cv
|
import cv2 as cv
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -16,19 +16,20 @@ from .image import ImageOutput
|
|||||||
|
|
||||||
# TODO: Populate this from disk?
|
# TODO: Populate this from disk?
|
||||||
# TODO: Use model manager to load?
|
# TODO: Use model manager to load?
|
||||||
REALESRGAN_MODELS = Literal[
|
ESRGAN_MODELS = Literal[
|
||||||
"RealESRGAN_x4plus.pth",
|
"RealESRGAN_x4plus.pth",
|
||||||
"RealESRGAN_x4plus_anime_6B.pth",
|
"RealESRGAN_x4plus_anime_6B.pth",
|
||||||
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
"ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
|
"RealESRGAN_x2plus.pth",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class RealESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
"""Upscales an image using RealESRGAN."""
|
"""Upscales an image using RealESRGAN."""
|
||||||
|
|
||||||
type: Literal["realesrgan"] = "realesrgan"
|
type: Literal["esrgan"] = "esrgan"
|
||||||
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
image: Union[ImageField, None] = Field(default=None, description="The input image")
|
||||||
model_name: REALESRGAN_MODELS = Field(
|
model_name: ESRGAN_MODELS = Field(
|
||||||
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -65,19 +66,17 @@ class RealESRGANInvocation(BaseInvocation):
|
|||||||
scale=4,
|
scale=4,
|
||||||
)
|
)
|
||||||
netscale = 4
|
netscale = 4
|
||||||
# TODO: add x2 models handling?
|
elif self.model_name in ["RealESRGAN_x2plus.pth"]:
|
||||||
# elif self.model_name in ["RealESRGAN_x2plus"]:
|
# x2 RRDBNet model
|
||||||
# # x2 RRDBNet model
|
rrdbnet_model = RRDBNet(
|
||||||
# model = RRDBNet(
|
num_in_ch=3,
|
||||||
# num_in_ch=3,
|
num_out_ch=3,
|
||||||
# num_out_ch=3,
|
num_feat=64,
|
||||||
# num_feat=64,
|
num_block=23,
|
||||||
# num_block=23,
|
num_grow_ch=32,
|
||||||
# num_grow_ch=32,
|
scale=2,
|
||||||
# scale=2,
|
)
|
||||||
# )
|
netscale = 2
|
||||||
# model_path = Path()
|
|
||||||
# netscale = 2
|
|
||||||
else:
|
else:
|
||||||
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
msg = f"Invalid RealESRGAN model: {self.model_name}"
|
||||||
context.services.logger.error(msg)
|
context.services.logger.error(msg)
|
||||||
|
@ -223,7 +223,7 @@ def download_conversion_models():
|
|||||||
|
|
||||||
# ---------------------------------------------
|
# ---------------------------------------------
|
||||||
def download_realesrgan():
|
def download_realesrgan():
|
||||||
logger.info("Installing RealESRGAN models...")
|
logger.info("Installing ESRGAN Upscaling models...")
|
||||||
URLs = [
|
URLs = [
|
||||||
dict(
|
dict(
|
||||||
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
|
||||||
@ -240,6 +240,11 @@ def download_realesrgan():
|
|||||||
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
dest= "core/upscaling/realesrgan/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth",
|
||||||
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
description = "ESRGAN_SRx4_DF2KOST_official.pth",
|
||||||
),
|
),
|
||||||
|
dict(
|
||||||
|
url= "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
|
||||||
|
dest= "core/upscaling/realesrgan/RealESRGAN_x2plus.pth",
|
||||||
|
description = "RealESRGAN_x2plus.pth",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
for model in URLs:
|
for model in URLs:
|
||||||
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
download_with_progress_bar(model['url'], config.models_path / model['dest'], model['description'])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user