From d868d5d584a2f50efffb004da2a15bdfb50f0166 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 10 Jul 2024 12:25:00 -0400 Subject: [PATCH] Make SpandrelImageToImage tiling much faster. --- .../invocations/spandrel_image_to_image.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 1591f51bec..788a59f36b 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,6 +1,4 @@ -import numpy as np import torch -from PIL import Image from tqdm import tqdm from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -16,7 +14,7 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel -from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap from invokeai.backend.tiles.utils import TBLR, Tile @@ -50,6 +48,29 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ), ) + def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor): + """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we + split the overlap in half. No 'blending' is applied. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. + tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True)) + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left) + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top) + + for tile, tile_tensor in tiles_and_tensors: + # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge + # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems + # unnecessary, but we may find a need in the future. + top_overlap = tile.overlap.top // 2 + left_overlap = tile.overlap.left // 2 + out_tensor[ + :, + :, + tile.coords.top + top_overlap : tile.coords.bottom, + tile.coords.left + left_overlap : tile.coords.right, + ] = tile_tensor[:, :, top_overlap:, left_overlap:] + @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to @@ -100,15 +121,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) output_tiles.append(output_tile) - # Merge tiles into output image. - np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles] - _, channels, height, width = image_tensor.shape - np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8) - merge_tiles_with_linear_blending( - dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2 + # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue: + # - Keep the input tensor on the CPU. + # - Move each tile to the GPU as it is processed. + # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor. + + # Merge the tiles to an output tensor. + batch_size, channels, height, width = image_tensor.shape + output_tensor = torch.zeros( + (batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device ) + self._merge_tiles(scaled_tiles, output_tiles, output_tensor) # Convert the output tensor to a PIL image. - pil_image = Image.fromarray(np_out_image) + pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto)