From a2ef5d56ee32e368c37925384648adcc39a9b763 Mon Sep 17 00:00:00 2001 From: psychedelicious <4822129+psychedelicious@users.noreply.github.com> Date: Tue, 23 Jul 2024 05:56:45 +1000 Subject: [PATCH] feat(nodes): split out spandrel node upscale logic into utils --- .../invocations/spandrel_image_to_image.py | 130 ++++++++++-------- 1 file changed, 73 insertions(+), 57 deletions(-) diff --git a/invokeai/app/invocations/spandrel_image_to_image.py b/invokeai/app/invocations/spandrel_image_to_image.py index bbe31af644..bf525d2a28 100644 --- a/invokeai/app/invocations/spandrel_image_to_image.py +++ b/invokeai/app/invocations/spandrel_image_to_image.py @@ -1,3 +1,5 @@ +from typing import Callable + import numpy as np import torch from PIL import Image @@ -35,7 +37,8 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling." ) - def _scale_tile(self, tile: Tile, scale: int) -> Tile: + @classmethod + def scale_tile(cls, tile: Tile, scale: int) -> Tile: return Tile( coords=TBLR( top=tile.coords.top * scale, @@ -51,20 +54,22 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): ), ) - @torch.inference_mode() - def invoke(self, context: InvocationContext) -> ImageOutput: - # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to - # revisit this. - image = context.images.get_pil(self.image.image_name, mode="RGB") - + @classmethod + def upscale_image( + cls, + image: Image.Image, + tile_size: int, + spandrel_model: SpandrelImageToImageModel, + is_canceled: Callable[[], bool], + ) -> Image.Image: # Compute the image tiles. - if self.tile_size > 0: + if tile_size > 0: min_overlap = 20 tiles = calc_tiles_min_overlap( image_height=image.height, image_width=image.width, - tile_height=self.tile_size, - tile_width=self.tile_size, + tile_height=tile_size, + tile_width=tile_size, min_overlap=min_overlap, ) else: @@ -85,6 +90,63 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): # Prepare input image for inference. image_tensor = SpandrelImageToImageModel.pil_to_tensor(image) + # Scale the tiles for re-assembling the final image. + scale = spandrel_model.scale + scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles] + + # Prepare the output tensor. + _, channels, height, width = image_tensor.shape + output_tensor = torch.zeros( + (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu") + ) + + image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) + + for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): + # Exit early if the invocation has been canceled. + if is_canceled(): + raise CanceledException + + # Extract the current tile from the input tensor. + input_tile = image_tensor[ + :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right + ].to(device=spandrel_model.device, dtype=spandrel_model.dtype) + + # Run the model on the tile. + output_tile = spandrel_model.run(input_tile) + + # Convert the output tile into the output tensor's format. + # (N, C, H, W) -> (C, H, W) + output_tile = output_tile.squeeze(0) + # (C, H, W) -> (H, W, C) + output_tile = output_tile.permute(1, 2, 0) + output_tile = output_tile.clamp(0, 1) + output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu")) + + # Merge the output tile into the output tensor. + # We only keep half of the overlap on the top and left side of the tile. We do this in case there are + # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers + # it seems unnecessary, but we may find a need in the future. + top_overlap = scaled_tile.overlap.top // 2 + left_overlap = scaled_tile.overlap.left // 2 + output_tensor[ + scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom, + scaled_tile.coords.left + left_overlap : scaled_tile.coords.right, + :, + ] = output_tile[top_overlap:, left_overlap:, :] + + # Convert the output tensor to a PIL image. + np_image = output_tensor.detach().numpy().astype(np.uint8) + pil_image = Image.fromarray(np_image) + + return pil_image + + @torch.inference_mode() + def invoke(self, context: InvocationContext) -> ImageOutput: + # Images are converted to RGB, because most models don't support an alpha channel. In the future, we may want to + # revisit this. + image = context.images.get_pil(self.image.image_name, mode="RGB") + # Load the model. spandrel_model_info = context.models.load(self.image_to_image_model) @@ -92,53 +154,7 @@ class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard): with spandrel_model_info as spandrel_model: assert isinstance(spandrel_model, SpandrelImageToImageModel) - # Scale the tiles for re-assembling the final image. - scale = spandrel_model.scale - scaled_tiles = [self._scale_tile(tile, scale=scale) for tile in tiles] + pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled) - # Prepare the output tensor. - _, channels, height, width = image_tensor.shape - output_tensor = torch.zeros( - (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu") - ) - - image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype) - - for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"): - # Exit early if the invocation has been canceled. - if context.util.is_canceled(): - raise CanceledException - - # Extract the current tile from the input tensor. - input_tile = image_tensor[ - :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right - ].to(device=spandrel_model.device, dtype=spandrel_model.dtype) - - # Run the model on the tile. - output_tile = spandrel_model.run(input_tile) - - # Convert the output tile into the output tensor's format. - # (N, C, H, W) -> (C, H, W) - output_tile = output_tile.squeeze(0) - # (C, H, W) -> (H, W, C) - output_tile = output_tile.permute(1, 2, 0) - output_tile = output_tile.clamp(0, 1) - output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu")) - - # Merge the output tile into the output tensor. - # We only keep half of the overlap on the top and left side of the tile. We do this in case there are - # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers - # it seems unnecessary, but we may find a need in the future. - top_overlap = scaled_tile.overlap.top // 2 - left_overlap = scaled_tile.overlap.left // 2 - output_tensor[ - scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom, - scaled_tile.coords.left + left_overlap : scaled_tile.coords.right, - :, - ] = output_tile[top_overlap:, left_overlap:, :] - - # Convert the output tensor to a PIL image. - np_image = output_tensor.detach().numpy().astype(np.uint8) - pil_image = Image.fromarray(np_image) image_dto = context.images.save(image=pil_image) return ImageOutput.build(image_dto)