From ab775726b7b61ce06142a1e9f2546d5528829ba5 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Tue, 9 Jul 2024 17:52:28 -0400 Subject: [PATCH 1/4] Add tiling support to the SpoandrelImageToImage node. --- .../invocations/spandrel_image_to_image.py | 79 +++++++++++++++++-- .../backend/spandrel_image_to_image_model.py | 5 ++ 2 files changed, 77 insertions(+), 7 deletions(-) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 76cf31480c..1591f51bec 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,4 +1,7 @@ +import numpy as np import torch +from PIL import Image +from tqdm import tqdm from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.fields import ( @@ -13,9 +16,11 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.utils import TBLR, Tile -@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.0.0") +@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.1.0") class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel).""" @@ -25,25 +30,85 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): description=FieldDescriptions.spandrel_image_to_image_model, ui_type=UIType.SpandrelImageToImageModel, ) + tile_size: int = InputField( + default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling." + ) + + def _scale_tile(self, tile: Tile, scale: int) -> Tile: + return Tile( + coords=TBLR( + top=tile.coords.top * scale, + bottom=tile.coords.bottom * scale, + left=tile.coords.left * scale, + right=tile.coords.right * scale, + ), + overlap=TBLR( + top=tile.overlap.top * scale, + bottom=tile.overlap.bottom * scale, + left=tile.overlap.left * scale, + right=tile.overlap.right * scale, + ), + ) @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: - image = context.images.get_pil(self.image.image_name) + # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to + # revisit this. + image = context.images.get_pil(self.image.image_name, mode="RGB") + + # Compute the image tiles. + if self.tile_size > 0: + min_overlap = 20 + tiles = calc_tiles_min_overlap( + image_height=image.height, + image_width=image.width, + tile_height=self.tile_size, + tile_width=self.tile_size, + min_overlap=min_overlap, + ) + else: + # No tiling. Generate a single tile that covers the entire image. + min_overlap = 0 + tiles = [ + Tile( + coords=TBLR(top=0, bottom=image.height, left=0, right=image.width), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + ] + + # Prepare input image for inference. + image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) # Load the model. spandrel_model_info = context.models.load(self.image_to_image_model) + # Run the model on each tile. + output_tiles: list[torch.Tensor] = [] + scale: int = 1 with spandrel_model_info as spandrel_model: assert isinstance(spandrel_model, SpandrelImageToImageModel) - # Prepare input image for inference. - image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) + # Scale the tiles for re-assembling the final image. + scale = spandrel_model.scale + scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles] + image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) - # Run inference. - image_tensor = spandrel_model.run(image_tensor) + for tile in tqdm(tiles, desc="Upscaling Tiles"): + output_tile = spandrel_model.run( + image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right] + ) + output_tiles.append(output_tile) + + # Merge tiles into output image. + np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles] + _, channels, height, width = image_tensor.shape + np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8) + merge_tiles_with_linear_blending( + dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2 + ) # Convert the output tensor to a PIL image. - pil_image = SpandrelImageToImageModel.tensor_to_pil(image_tensor) + pil_image = Image.fromarray(np_out_image) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto) diff --git a/invokeai/backend/spandrel_image_to_image_model.py b/invokeai/backend/spandrel_image_to_image_model.py index adb78d0d71..ccf02c57ac 100644 --- a/invokeai/backend/spandrel_image_to_image_model.py +++ b/invokeai/backend/spandrel_image_to_image_model.py @@ -126,6 +126,11 @@ class SpandrelImageToImageModel(RawModel): """The dtype of the underlying model.""" return self._spandrel_model.dtype + @property + def scale(self) -> int: + """The scale of the model (e.g. 1x, 2x, 4x, etc.).""" + return self._spandrel_model.scale + def calc_size(self) -> int: """Get size of the model in memory in bytes.""" # HACK(ryand): Fix this issue with circular imports. From d868d5d584a2f50efffb004da2a15bdfb50f0166 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 10 Jul 2024 12:25:00 -0400 Subject: [PATCH 2/4] Make SpandrelImageToImage tiling much faster. --- .../invocations/spandrel_image_to_image.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 1591f51bec..788a59f36b 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,6 +1,4 @@ -import numpy as np import torch -from PIL import Image from tqdm import tqdm from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -16,7 +14,7 @@ from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel -from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.tiles import calc_tiles_min_overlap from invokeai.backend.tiles.utils import TBLR, Tile @@ -50,6 +48,29 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ), ) + def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor): + """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we + split the overlap in half. No 'blending' is applied. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. + tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True)) + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left) + tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top) + + for tile, tile_tensor in tiles_and_tensors: + # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge + # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems + # unnecessary, but we may find a need in the future. + top_overlap = tile.overlap.top // 2 + left_overlap = tile.overlap.left // 2 + out_tensor[ + :, + :, + tile.coords.top + top_overlap : tile.coords.bottom, + tile.coords.left + left_overlap : tile.coords.right, + ] = tile_tensor[:, :, top_overlap:, left_overlap:] + @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to @@ -100,15 +121,19 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) output_tiles.append(output_tile) - # Merge tiles into output image. - np_output_tiles = [np.array(SpandrelImageToImageModel.tensor_to_pil(tile)) for tile in output_tiles] - _, channels, height, width = image_tensor.shape - np_out_image = np.zeros((height * scale, width * scale, channels), dtype=np.uint8) - merge_tiles_with_linear_blending( - dst_image=np_out_image, tiles=scaled_tiles, tile_images=np_output_tiles, blend_amount=min_overlap // 2 + # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue: + # - Keep the input tensor on the CPU. + # - Move each tile to the GPU as it is processed. + # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor. + + # Merge the tiles to an output tensor. + batch_size, channels, height, width = image_tensor.shape + output_tensor = torch.zeros( + (batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device ) + self._merge_tiles(scaled_tiles, output_tiles, output_tensor) # Convert the output tensor to a PIL image. - pil_image = Image.fromarray(np_out_image) + pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto) From d0d295599215fcb7eb87981a79ad10450595679b Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 10 Jul 2024 12:56:09 -0400 Subject: [PATCH 3/4] Reduce peak VRAM utilization of SpandrelImageToImageInvocation. --- .../invocations/spandrel_image_to_image.py | 82 +++++++++---------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 788a59f36b..650c9bb547 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,4 +1,6 @@ +import numpy as np import torch +from PIL import Image from tqdm import tqdm from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -48,29 +50,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ), ) - def _merge_tiles(self, tiles: list[Tile], tile_tensors: list[torch.Tensor], out_tensor: torch.Tensor): - """A simple tile merging algorithm. tile_tensors are merged into out_tensor. When adjacent tiles overlap, we - split the overlap in half. No 'blending' is applied. - """ - # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to - # iterate over tiles left-to-right, top-to-bottom. - tiles_and_tensors = list(zip(tiles, tile_tensors, strict=True)) - tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.left) - tiles_and_tensors = sorted(tiles_and_tensors, key=lambda x: x[0].coords.top) - - for tile, tile_tensor in tiles_and_tensors: - # We only keep half of the overlap on the top and left side of the tile. We do this in case there are edge - # artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers it seems - # unnecessary, but we may find a need in the future. - top_overlap = tile.overlap.top // 2 - left_overlap = tile.overlap.left // 2 - out_tensor[ - :, - :, - tile.coords.top + top_overlap : tile.coords.bottom, - tile.coords.left + left_overlap : tile.coords.right, - ] = tile_tensor[:, :, top_overlap:, left_overlap:] - @torch.inference_mode() def invoke(self, context: InvocationContext) -> ImageOutput: # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to @@ -97,6 +76,11 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ) ] + # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate + # over tiles left-to-right, top-to-bottom. + tiles = sorted(tiles, key=lambda x: x.coords.left) + tiles = sorted(tiles, key=lambda x: x.coords.top) + # Prepare input image for inference. image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) @@ -104,8 +88,6 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): spandrel_model_info = context.models.load(self.image_to_image_model) # Run the model on each tile. - output_tiles: list[torch.Tensor] = [] - scale: int = 1 with spandrel_model_info as spandrel_model: assert isinstance(spandrel_model, SpandrelImageToImageModel) @@ -113,27 +95,45 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): scale = spandrel_model.scale scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles] + # Prepare the output tensor. + _, channels, height, width = image_tensor.shape + output_tensor = torch.zeros( + (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu") + ) + image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) - for tile in tqdm(tiles, desc="Upscaling Tiles"): - output_tile = spandrel_model.run( - image_tensor[:, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right] - ) - output_tiles.append(output_tile) + for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): + # Extract the current tile from the input tensor. + input_tile = image_tensor[ + :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right + ].to(device=spandrel_model.device, dtype=spandrel_model.dtype) - # TODO(ryand): There are opportunities to reduce peak VRAM utilization here if it becomes an issue: - # - Keep the input tensor on the CPU. - # - Move each tile to the GPU as it is processed. - # - Move output tensors back to the CPU as they are produced, and merge them into the output tensor. + # Run the model on the tile. + output_tile = spandrel_model.run(input_tile) - # Merge the tiles to an output tensor. - batch_size, channels, height, width = image_tensor.shape - output_tensor = torch.zeros( - (batch_size, channels, height * scale, width * scale), dtype=image_tensor.dtype, device=image_tensor.device - ) - self._merge_tiles(scaled_tiles, output_tiles, output_tensor) + # Convert the output tile into the output tensor's format. + # (N, C, H, W) -> (C, H, W) + output_tile = output_tile.squeeze(0) + # (C, H, W) -> (H, W, C) + output_tile = output_tile.permute(1, 2, 0) + output_tile = output_tile.clamp(0, 1) + output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu")) + + # Merge the output tile into the output tensor. + # We only keep half of the overlap on the top and left side of the tile. We do this in case there are + # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers + # it seems unnecessary, but we may find a need in the future. + top_overlap = scaled_tile.overlap.top // 2 + left_overlap = scaled_tile.overlap.left // 2 + output_tensor[ + scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom, + scaled_tile.coords.left + left_overlap : scaled_tile.coords.right, + :, + ] = output_tile[top_overlap:, left_overlap:, :] # Convert the output tensor to a PIL image. - pil_image = SpandrelImageToImageModel.tensor_to_pil(output_tensor) + np_image = output_tensor.detach().numpy().astype(np.uint8) + pil_image = Image.fromarray(np_image) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto) From 0428ce73a9c10a1fe449e361b6957ecf4dc5c71d Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Thu, 11 Jul 2024 15:42:33 -0400 Subject: [PATCH 4/4] Add early cancellation to SpandrelImageToImageInvocation. --- invokeai/app/invocations/spandrel_image_to_image.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index 650c9bb547..bbe31af644 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -14,6 +14,7 @@ from invokeai.app.invocations.fields import ( ) from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.services.session_processor.session_processor_common import CanceledException from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel from invokeai.backend.tiles.tiles import calc_tiles_min_overlap @@ -104,6 +105,10 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): + # Exit early if the invocation has been canceled. + if context.util.is_canceled(): + raise CanceledException + # Extract the current tile from the input tensor. input_tile = image_tensor[ :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right