mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(nodes): clear torch cache after upscaling
This can use many GB of VRAM, so we need to clean up after ourselves.
This commit is contained in:
@ -7,9 +7,11 @@ import numpy as np
|
|||||||
from basicsr.archs.rrdbnet_arch import RRDBNet
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from realesrgan import RealESRGANer
|
from realesrgan import RealESRGANer
|
||||||
|
import torch
|
||||||
|
|
||||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||||
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
from invokeai.app.models.image import ImageCategory, ResourceOrigin
|
||||||
|
from invokeai.backend.util.devices import choose_torch_device
|
||||||
|
|
||||||
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
from .baseinvocation import BaseInvocation, InputField, InvocationContext, invocation
|
||||||
|
|
||||||
@ -22,6 +24,9 @@ ESRGAN_MODELS = Literal[
|
|||||||
"RealESRGAN_x2plus.pth",
|
"RealESRGAN_x2plus.pth",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if choose_torch_device() == torch.device("mps"):
|
||||||
|
from torch import mps
|
||||||
|
|
||||||
|
|
||||||
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.1")
|
@invocation("esrgan", title="Upscale (RealESRGAN)", tags=["esrgan", "upscale"], category="esrgan", version="1.0.1")
|
||||||
class ESRGANInvocation(BaseInvocation):
|
class ESRGANInvocation(BaseInvocation):
|
||||||
@ -30,7 +35,7 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
image: ImageField = InputField(description="The input image")
|
image: ImageField = InputField(description="The input image")
|
||||||
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
model_name: ESRGAN_MODELS = InputField(default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use")
|
||||||
tile_size: int = InputField(
|
tile_size: int = InputField(
|
||||||
default=512, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
default=400, ge=0, description="Tile size for tiled ESRGAN upscaling (0=tiling disabled)"
|
||||||
)
|
)
|
||||||
|
|
||||||
def invoke(self, context: InvocationContext) -> ImageOutput:
|
def invoke(self, context: InvocationContext) -> ImageOutput:
|
||||||
@ -93,6 +98,7 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
# prepare image - Real-ESRGAN uses cv2 internally, and cv2 uses BGR vs RGB for PIL
|
||||||
|
# TODO: This strips the alpha... is that okay?
|
||||||
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)
|
||||||
|
|
||||||
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
# We can pass an `outscale` value here, but it just resizes the image by that factor after
|
||||||
@ -103,6 +109,10 @@ class ESRGANInvocation(BaseInvocation):
|
|||||||
# back to PIL
|
# back to PIL
|
||||||
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
pil_image = Image.fromarray(cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)).convert("RGBA")
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
if choose_torch_device() == torch.device("mps"):
|
||||||
|
mps.empty_cache()
|
||||||
|
|
||||||
image_dto = context.services.images.create(
|
image_dto = context.services.images.create(
|
||||||
image=pil_image,
|
image=pil_image,
|
||||||
image_origin=ResourceOrigin.INTERNAL,
|
image_origin=ResourceOrigin.INTERNAL,
|
||||||
|
Reference in New Issue
Block a user