feat(nodes): clear torch cache after upscaling

Upscaling can use many GB of VRAM, so we need to clean up after ourselves.
Author: psychedelicious
Date:   2023-10-04 15:23:31 +11:00
Parent: 010c8e8038
Commit: 23a06fd06d

@@ -7,9 +7,11 @@ import numpy as np
 from basicsr.archs.rrdbnet_arch import RRDBNet
 from PIL import Image
 from realesrgan import RealESRGANer
+import torch

 from invokeai.app.invocations.primitives import ImageField, ImageOutput
 from invokeai.app.models.image import ImageCategory, ResourceOrigin
+from invokeai.backend.util.devices import choose_torch_device

 from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
@@ -22,6 +24,9 @@ ESRGAN_MODELS = Literal[
     "RealESRGAN_x2plus.pth",
 ]

+if choose_torch_device() == torch.device("mps"):
+    from torch import mps
+

 @invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.1")
 class ESRGANInvocation(BaseInvocation):
@@ -30,7 +35,7 @@ class ESRGANInvocation(BaseInvocation):
     image: ImageField = InputField(description="The input image")
     model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
     tile_size: int = InputField(
-        default=512, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
+        default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
     )

     def invoke(self, context: InvocationContext) -> ImageOutput:
@@ -93,6 +98,7 @@ class ESRGANInvocation(BaseInvocation):
             )

         # prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
+        # TODO: This strips the alpha... is that okay?
         cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

         # We can pass an `outscale` value here, but it just resizes the image by that factor after
@@ -103,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
         # back to PIL
         pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")

+        torch.cuda.empty_cache()
+        if choose_torch_device() == torch.device("mps"):
+            mps.empty_cache()
+
         image_dto = context.services.images.create(
             image=pil_image,
             image_origin=ResourceOrigin.INTERNAL,
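
For reference, the cleanup this diff introduces amounts to emptying the active torch allocator cache once the upscaled image has been produced. The standalone sketch below illustrates that pattern; it is not part of the commit, free_torch_cache is an illustrative name, and choose_torch_device here is a simplified stand-in for invokeai.backend.util.devices.choose_torch_device.

import torch


def choose_torch_device() -> torch.device:
    # Simplified stand-in (assumption) for invokeai.backend.util.devices.choose_torch_device:
    # prefer CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def free_torch_cache() -> None:
    # Illustrative helper, not from the commit.
    # torch.cuda.empty_cache() is a no-op when CUDA has not been initialized,
    # so calling it unconditionally (as the diff does) is safe on any machine.
    torch.cuda.empty_cache()
    # MPS maintains its own allocator cache and needs an explicit release on Apple silicon.
    if choose_torch_device() == torch.device("mps"):
        from torch import mps

        mps.empty_cache()


# Usage: run the VRAM-heavy Real-ESRGAN upscale, then release cached blocks:
#     pil_image = ...  # upscaled output
#     free_torch_cache()

Releasing the cache right after the output image is built keeps multi-gigabyte allocator blocks from lingering between invocations; it does not free tensors that are still referenced.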