feat(nodes): remove dependency on realesrgan

We used the `RealESRGANer` utility class from the `realesrgan` package. It handled model loading and the tiled upscaling logic.

Unfortunately, it hasn't been updated in over a year, has no types, and annoyingly prints to the console.

I've adapted the class, cleaning it up a bit and removing the bits that are not relevant for us.

Upscaling functionality is identical.
This commit is contained in:
psychedelicious
2023-11-27 21:22:31 +11:00
parent 84629df49c
commit 2192210910
3 changed files with 279 additions and 7 deletions

View File

@ -2,16 +2,16 @@
from pathlib import Path
from typing import Literal
import cv2 as cv
import cv2
import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import ConfigDict
from realesrgan import RealESRGANer
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.backend.image_util.esrgan import RealESRGANer
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, WithMetadata, WithWorkflow, invocation
@ -94,7 +94,7 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
upsampler = RealESRGANer(
scale=netscale,
model_path=str(models_path / esrgan_model_path),
model_path=models_path / esrgan_model_path,
model=rrdbnet_model,
half=False,
tile=self.tile_size,
@ -102,15 +102,15 @@ class ESRGANInvocation(BaseInvocation, WithWorkflow, WithMetadata):
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR channel order whereas PIL uses RGB
# TODO: This strips the alpha... is that okay?
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
cv2_image = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
# upscaling, so it's kinda pointless for our purposes. If you want something other than 4x
# upscaling, you'll need to add a resize node after this one.
upscaled_image, img_mode = upsampler.enhance(cv_image)
upscaled_image = upsampler.enhance(cv2_image)
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
pil_image = Image.fromarray(cv2.cvtColor(upscaled_image, cv2.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):