from typing import Callable

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.fields import (
    FieldDescriptions,
    ImageField,
    InputField,
    UIType,
    WithBoard,
    WithMetadata,
)
from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.session_processor.session_processor_common import CanceledException
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.spandrel_image_to_image_model import SpandrelImageToImageModel
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import TBLR, Tile


@invocation("spandrel_image_to_image", title="Image-to-Image", tags=["upscale"], category="upscale", version="1.3.0")
class SpandrelImageToImageInvocation(BaseInvocation, WithMetadata, WithBoard):
    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel)."""

    image: ImageField = InputField(description="The input image")
    image_to_image_model: ModelIdentifierField = InputField(
        title="Image-to-Image Model",
        description=FieldDescriptions.spandrel_image_to_image_model,
        ui_type=UIType.SpandrelImageToImageModel,
    )
    tile_size: int = InputField(
        default=512, description="The tile size for tiled image-to-image. Set to 0 to disable tiling."
    )

    @classmethod
    def scale_tile(cls, tile: Tile, scale: int) -> Tile:
        """Scale a tile's coordinates and overlaps by `scale`."""
        return Tile(
            coords=TBLR(
                top=tile.coords.top * scale,
                bottom=tile.coords.bottom * scale,
                left=tile.coords.left * scale,
                right=tile.coords.right * scale,
            ),
            overlap=TBLR(
                top=tile.overlap.top * scale,
                bottom=tile.overlap.bottom * scale,
                left=tile.overlap.left * scale,
                right=tile.overlap.right * scale,
            ),
        )

    @classmethod
    def upscale_image(
        cls,
        image: Image.Image,
        tile_size: int,
        spandrel_model: SpandrelImageToImageModel,
        is_canceled: Callable[[], bool],
    ) -> Image.Image:
        """Run `spandrel_model` on `image`, processing it in tiles of `tile_size` (set to 0 to disable tiling)."""
        # Compute the image tiles.
        if tile_size > 0:
            min_overlap = 20
            tiles = calc_tiles_min_overlap(
                image_height=image.height,
                image_width=image.width,
                tile_height=tile_size,
                tile_width=tile_size,
                min_overlap=min_overlap,
            )
        else:
            # No tiling. Generate a single tile that covers the entire image.
            min_overlap = 0
            tiles = [
                Tile(
                    coords=TBLR(top=0, bottom=image.height, left=0, right=image.width),
                    overlap=TBLR(top=0, bottom=0, left=0, right=0),
                )
            ]

        # Sort tiles first by left x coordinate, then by top y coordinate. During tile processing, we want to iterate
        # over tiles left-to-right, top-to-bottom.
        tiles = sorted(tiles, key=lambda x: x.coords.left)
        tiles = sorted(tiles, key=lambda x: x.coords.top)

        # Prepare the input image for inference.
        image_tensor = SpandrelImageToImageModel.pil_to_tensor(image)

        # Scale the tiles for re-assembling the final image.
        scale = spandrel_model.scale
        scaled_tiles = [cls.scale_tile(tile, scale=scale) for tile in tiles]

        # Prepare the output tensor.
        _, channels, height, width = image_tensor.shape
        output_tensor = torch.zeros(
            (height * scale, width * scale, channels), dtype=torch.uint8, device=torch.device("cpu")
        )

        image_tensor = image_tensor.to(device=spandrel_model.device, dtype=spandrel_model.dtype)

        # Run the model on each tile.
        for tile, scaled_tile in tqdm(list(zip(tiles, scaled_tiles, strict=True)), desc="Upscaling Tiles"):
            # Exit early if the invocation has been canceled.
            if is_canceled():
                raise CanceledException

            # Extract the current tile from the input tensor.
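            # `image_tensor` has (N, C, H, W) layout (see the shape unpacking above), so the tile's top/bottom
            # coordinates index the height dimension and left/right index the width dimension.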
            input_tile = image_tensor[
                :, :, tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right
            ].to(device=spandrel_model.device, dtype=spandrel_model.dtype)

            # Run the model on the tile.
            output_tile = spandrel_model.run(input_tile)

            # Convert the output tile into the output tensor's format.
            # (N, C, H, W) -> (C, H, W)
            output_tile = output_tile.squeeze(0)
            # (C, H, W) -> (H, W, C)
            output_tile = output_tile.permute(1, 2, 0)
            output_tile = output_tile.clamp(0, 1)
            output_tile = (output_tile * 255).to(dtype=torch.uint8, device=torch.device("cpu"))

            # Merge the output tile into the output tensor.
            # We only keep half of the overlap on the top and left side of the tile. We do this in case there are
            # edge artifacts. We don't bother with any 'blending' in the current implementation - for most upscalers
            # it seems unnecessary, but we may find a need in the future.
            top_overlap = scaled_tile.overlap.top // 2
            left_overlap = scaled_tile.overlap.left // 2
            output_tensor[
                scaled_tile.coords.top + top_overlap : scaled_tile.coords.bottom,
                scaled_tile.coords.left + left_overlap : scaled_tile.coords.right,
                :,
            ] = output_tile[top_overlap:, left_overlap:, :]

        # Convert the output tensor to a PIL image.
        np_image = output_tensor.detach().numpy().astype(np.uint8)
        pil_image = Image.fromarray(np_image)

        return pil_image

    @torch.inference_mode()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Images are converted to RGB because most models don't support an alpha channel. In the future, we may want
        # to revisit this.
        image = context.images.get_pil(self.image.image_name, mode="RGB")

        # Load the model.
        spandrel_model_info = context.models.load(self.image_to_image_model)

        # Do the upscaling.
        with spandrel_model_info as spandrel_model:
            assert isinstance(spandrel_model, SpandrelImageToImageModel)

            # Upscale the image.
            pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)

        image_dto = context.images.save(image=pil_image)

        return ImageOutput.build(image_dto)


@invocation(
    "spandrel_image_to_image_autoscale",
    title="Image-to-Image (Autoscale)",
    tags=["upscale"],
    category="upscale",
    version="1.0.0",
)
class SpandrelImageToImageAutoscaleInvocation(SpandrelImageToImageInvocation):
    """Run any spandrel image-to-image model (https://github.com/chaiNNer-org/spandrel) until the target scale is reached."""

    scale: float = InputField(
        default=4.0,
        gt=0.0,
        le=16.0,
        description="The final scale of the output image. If the model does not upscale the image, this will be ignored.",
    )
    fit_to_multiple_of_8: bool = InputField(
        default=False,
        description="If true, the output image will be resized to the nearest multiple of 8 in both dimensions.",
    )

    @torch.inference_mode()
    def invoke(self, context: InvocationContext) -> ImageOutput:
        # Images are converted to RGB because most models don't support an alpha channel. In the future, we may want
        # to revisit this.
        image = context.images.get_pil(self.image.image_name, mode="RGB")

        # Load the model.
        spandrel_model_info = context.models.load(self.image_to_image_model)

        # The target size of the image, determined by the requested scale. We'll run the upscaler until we hit this
        # size. We may mutate these values later if the model doesn't upscale the image or if the user requested a
        # multiple of 8.
        target_width = int(image.width * self.scale)
        target_height = int(image.height * self.scale)

        # Do the upscaling.
        with spandrel_model_info as spandrel_model:
            assert isinstance(spandrel_model, SpandrelImageToImageModel)
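            # The model manager types the loaded model generically; the assert above guards at runtime and narrows
            # the type for static checkers before the `upscale_image` calls below.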

            # First pass of upscaling. Note: `pil_image` is reassigned on each subsequent pass in the loop below.
            pil_image = self.upscale_image(image, self.tile_size, spandrel_model, context.util.is_canceled)

            # Some models don't upscale the image, but we have no way to know this in advance. We'll check whether
            # the model upscaled the image and run the loop below only if it did. We require the model to increase
            # both dimensions to be considered an upscale model.
            is_upscale_model = pil_image.width > image.width and pil_image.height > image.height
            if is_upscale_model:
                # This is an upscale model, so we should keep upscaling until we reach the target size.
                iterations = 1
                while pil_image.width < target_width or pil_image.height < target_height:
                    pil_image = self.upscale_image(pil_image, self.tile_size, spandrel_model, context.util.is_canceled)
                    iterations += 1

                    # Sanity check to prevent excessive or infinite loops. All known upscaling models are at least
                    # 2x. Our max scale is 16x, so with a 2x model we should never need more than 16x == 2^4 -> 4
                    # iterations. We allow one extra iteration "just in case" and bail after 5 upscaling iterations.
                    # In practice, we should never reach this limit.
                    if iterations >= 5:
                        context.logger.warning(
                            "Upscale loop reached maximum iteration count of 5, stopping upscaling early."
                        )
                        break
            else:
                # This model doesn't upscale the image. Ignore the scale parameter and treat the processed image's
                # size as the output size.
                target_width = pil_image.width
                target_height = pil_image.height

                # Warn the user if they requested a scale greater than 1.
                if self.scale > 1:
                    context.logger.warning(
                        "Model does not increase the size of the image, but a scale greater than 1 was requested. "
                        "The image will not be scaled."
                    )

        # We may need to resize the image to a multiple of 8. Use floor division so that we never scale the image up
        # in the final resize.
        if self.fit_to_multiple_of_8:
            target_width = int(target_width // 8 * 8)
            target_height = int(target_height // 8 * 8)

        # Final resize. Per PIL documentation, Lanczos provides the best quality for both upscaling and downscaling.
        # See: https://pillow.readthedocs.io/en/stable/handbook/concepts.html#filters-comparison-table
        pil_image = pil_image.resize((target_width, target_height), resample=Image.Resampling.LANCZOS)

        image_dto = context.images.save(image=pil_image)

        return ImageOutput.build(image_dto)
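
# Worked example for `fit_to_multiple_of_8` (illustrative, not executed): a 4x upscale of a 251x251 input yields
# 1004x1004, and floor division brings the target down to 1000x1000 rather than up to 1008x1008, so the final
# Lanczos resize never enlarges the image.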