Add tiling to SpandrelImageToImageInvocation (#6594)

## Summary

Add tiling to the `SpandrelImageToImageInvocation` node so that it can
process large images.

Tiling enables this node to run on effectively any input image
dimension. Of course, the computation time increases quadratically with
the image dimension.

Some profiling results on an RTX4090:
- Input 1024x1024, 4x upscale, 4x UltraSharp ESRGAN: `13 secs`, `<4 GB
VRAM`
- Input 4096x4096, 4x upscale, 4x UltraSharop ESRGAN: `46 secs`, `<4 GB
VRAM`
- Input 4096x4096, 2x upscale, SwinIR: `165 secs`, `<5 GB VRAM`

A lot of the time is spent PNG encoding the final image:
- PNG encoding of a 16384x16384 image takes `83secs @
pil_compress_level=7`, `24secs @ pil_compress_level=1`

Callout: If we want to start building workflows that pass large images
between nodes, we are going to have to find a way to avoid the PNG
encode/decode roundtrip that we are currently doing. As is, we will be
incurring a huge penalty for every node that receives/produces a large
image.

## QA Instructions

- [x] Tested with tiling up to 4096x4096 -> 16384x16384.
- [x] Test on images with an alpha channel (the alpha channel is
dropped).
- [x] Test on images with odd dimension.
- [x] Test no tiling (`tile_size=0`)

## Merge Plan

- [x] Merge #6556 first, and change the target branch to `main`.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a
changelog_
- [ ] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
This commit is contained in:
Ryan Dick 2024-07-16 15:51:15 -04:00 committed by GitHub
commit 95e9f5323b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 107 additions and 7 deletions

View File

@ -1,4 +1,7 @@
import numpy as np
import torch import torch
from PIL import Image
from tqdm import tqdm
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
@ -11,11 +14,14 @@ from invokeai.app.invocations.fields import (
) )
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0") @invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
"""Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel).""" """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""
@ -25,25 +31,114 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
description=FieldDescriptions.spandrel_image_to_image_model, description=FieldDescriptions.spandrel_image_to_image_model,
ui_type=UIType.SpandrelImageToImageModel, ui_type=UIType.SpandrelImageToImageModel,
) )
tile_size: int = InputField(
default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
)
def _scale_tile(self, tile: Tile, scale: int) -> Tile:
return Tile(
coords=TBLR(
top=tile.coords.top * scale,
bottom=tile.coords.bottom * scale,
left=tile.coords.left * scale,
right=tile.coords.right * scale,
),
overlap=TBLR(
top=tile.overlap.top * scale,
bottom=tile.overlap.bottom * scale,
left=tile.overlap.left * scale,
right=tile.overlap.right * scale,
),
)
@torch.inference_mode() @torch.inference_mode()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
image = context.images.get_pil(self.image.image_name) # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to
# revisit this.
image = context.images.get_pil(self.image.image_name, mode="RGB")
# Compute the image tiles.
if self.tile_size > 0:
min_overlap = 20
tiles = calc_tiles_min_overlap(
image_height=image.height,
image_width=image.width,
tile_height=self.tile_size,
tile_width=self.tile_size,
min_overlap=min_overlap,
)
else:
# No tiling. Generate a single tile that covers the entire image.
min_overlap = 0
tiles = [
Tile(
coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
overlap=TBLR(top=0, bottom=0, left=0, right=0),
)
]
# Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
# over tiles left-to-right, top-to-bottom.
tiles = sorted(tiles, key=lambda x: x.coords.left)
tiles = sorted(tiles, key=lambda x: x.coords.top)
# Prepare input image for inference.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)
# Load the model. # Load the model.
spandrel_model_info = context.models.load(self.image_to_image_model) spandrel_model_info = context.models.load(self.image_to_image_model)
# Run the model on each tile.
with spandrel_model_info as spandrel_model: with spandrel_model_info as spandrel_model:
assert isinstance(spandrel_model, SpandrelImageToImageModel) assert isinstance(spandrel_model, SpandrelImageToImageModel)
# Prepare input image for inference. # Scale the tiles for re-assembling the final image.
image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) scale = spandrel_model.scale
scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles]
# Prepare the output tensor.
_, channels, height, width = image_tensor.shape
output_tensor = torch.zeros(
(height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
)
image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run inference. for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
image_tensor = spandrel_model.run(image_tensor) # Exit early if the invocation has been canceled.
if context.util.is_canceled():
raise CanceledException
# Extract the current tile from the input tensor.
input_tile = image_tensor[
:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
].to(device=spandrel_model.device, dtype=spandrel_model.dtype)
# Run the model on the tile.
output_tile = spandrel_model.run(input_tile)
# Convert the output tile into the output tensor's format.
# (N, C, H, W) -> (C, H, W)
output_tile = output_tile.squeeze(0)
# (C, H, W) -> (H, W, C)
output_tile = output_tile.permute(1, 2, 0)
output_tile = output_tile.clamp(0, 1)
output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))
# Merge the output tile into the output tensor.
# We only keep half of the overlap on the top and left side of the tile. We do this in case there are
# edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
# it seems unnecessary, but we may find a need in the future.
top_overlap = scaled_tile.overlap.top // 2
left_overlap = scaled_tile.overlap.left // 2
output_tensor[
scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
:,
] = output_tile[top_overlap:, left_overlap:, :]
# Convert the output tensor to a PIL image. # Convert the output tensor to a PIL image.
pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor) np_image = output_tensor.detach().numpy().astype(np.uint8)
pil_image = Image.fromarray(np_image)
image_dto = context.images.save(image=pil_image) image_dto = context.images.save(image=pil_image)
return ImageOutput.build(image_dto) return ImageOutput.build(image_dto)

View File

@ -126,6 +126,11 @@ class SpandrelImageToImageModel(RawModel):
"""The dtype of the underlying model.""" """The dtype of the underlying model."""
return self._spandrel_model.dtype return self._spandrel_model.dtype
@property
def scale(self) -> int:
"""The scale of the model (e.g. 1x, 2x, 4x, etc.)."""
return self._spandrel_model.scale
def calc_size(self) -> int: def calc_size(self) -> int:
"""Get size of the model in memory in bytes.""" """Get size of the model in memory in bytes."""
# HACK(ryand): Fix this issue with circular imports. # HACK(ryand): Fix this issue with circular imports.