feat(nodes): clear torch cache after upscaling

This can use many GB of VRAM, so we need to clean up after ourselves.
This commit is contained in:
psychedelicious 2023-10-04 15:23:31 +11:00
parent 010c8e8038
commit 23a06fd06d

View File

@ -7,9 +7,11 @@ import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from realesrgan import RealESRGANer
import torch
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from invokeai.app.models.image import ImageCategory, ResourceOrigin
from invokeai.backend.util.devices import choose_torch_device
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@ -22,6 +24,9 @@ ESRGAN_MODELS = Literal[
"RealESRGAN_x2plus.pth",
]
if choose_torch_device() == torch.device("mps"):
from torch import mps
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.1")
class ESRGANInvocation(BaseInvocation):
@ -30,7 +35,7 @@ class ESRGANInvocation(BaseInvocation):
image: ImageField = InputField(description="The input image")
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
tile_size: int = InputField(
default=512, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
)
def invoke(self, context: InvocationContext) -> ImageOutput:
@ -93,6 +98,7 @@ class ESRGANInvocation(BaseInvocation):
)
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
# TODO: This strips the alpha... is that okay?
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
# We can pass an `outscale` value here, but it just resizes the image by that factor after
@ -103,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
# back to PIL
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
image_dto = context.services.images.create(
image=pil_image,
image_origin=ResourceOrigin.INTERNAL,