From 674d9796d079f981a34569756c97d68ba6507381 Mon Sep 17 00:00:00 2001 From: skunkworxdark Date: Tue, 5 Dec 2023 21:03:16 +0000 Subject: [PATCH] First check-in of new tile nodes - calc_tiles_even_split - calc_tiles_min_overlap - merge_tiles_with_seam_blending Update MergeTilesToImageInvocation with seam blending --- invokeai/app/invocations/tiles.py | 114 +++++++++++++- invokeai/backend/tiles/tiles.py | 245 +++++++++++++++++++++++++++--- invokeai/backend/tiles/utils.py | 100 ++++++++++++ 3 files changed, 435 insertions(+), 24 deletions(-) diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index 3055c1baae..609b539610 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np from PIL import Image from pydantic import BaseModel @@ -5,6 +7,7 @@ from pydantic import BaseModel from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Input, InputField, InvocationContext, OutputField, @@ -15,7 +18,13 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.tiles import ( + calc_tiles_even_split, + calc_tiles_min_overlap, + calc_tiles_with_overlap, + merge_tiles_with_linear_blending, + merge_tiles_with_seam_blending, +) from invokeai.backend.tiles.utils import Tile @@ -56,6 +65,86 @@ class CalculateImageTilesInvocation(BaseInvocation): return CalculateImageTilesOutput(tiles=tiles) +@invocation( + "calculate_image_tiles_Even_Split", + title="Calculate Image Tiles Even Split", + tags=["tiles"], + category="tiles", + version="1.0.0", +) +class CalculateImageTilesEvenSplitInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + num_tiles_x: int = InputField( + default=2, + ge=1, + description="Number of tiles to divide image into on the x axis", + ) + num_tiles_y: int = InputField( + default=2, + ge=1, + description="Number of tiles to divide image into on the y axis", + ) + overlap: float = InputField( + default=0.25, + ge=0, + lt=1, + description="Overlap amount of tile size (0-1)", + ) + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_even_split( + image_height=self.image_height, + image_width=self.image_width, + num_tiles_x=self.num_tiles_x, + num_tiles_y=self.num_tiles_y, + overlap=self.overlap, + ) + return CalculateImageTilesOutput(tiles=tiles) + + +@invocation( + "calculate_image_tiles_min_overlap", + title="Calculate Image Tiles Minimum Overlap", + tags=["tiles"], + category="tiles", + version="1.0.0", +) +class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.") + tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") + min_overlap: int = InputField( + default=128, + ge=0, + description="minimum tile overlap size (must be a multiple of 8)", + ) + round_to_8: bool = InputField( + default=False, + description="Round outputs down to the nearest 8 (for pulling from a large noise field)", + ) + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_min_overlap( + image_height=self.image_height, + image_width=self.image_width, + tile_height=self.tile_height, + tile_width=self.tile_width, + min_overlap=self.min_overlap, + round_to_8=self.round_to_8, + ) + return CalculateImageTilesOutput(tiles=tiles) + + @invocation_output("tile_to_properties_output") class TileToPropertiesOutput(BaseInvocationOutput): coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.") @@ -122,13 +211,22 @@ class PairTileImageInvocation(BaseInvocation): ) -@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.0.0") +BLEND_MODES = Literal["Linear", "Seam"] + + +@invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0") class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): """Merge multiple tile images into a single image.""" # Inputs tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.") + blend_mode: BLEND_MODES = InputField( + default="Seam", + description="blending type Linear or Seam", + input=Input.Direct, + ) blend_amount: int = InputField( + default=32, ge=0, description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) @@ -158,10 +256,16 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata, WithWorkflow): channels = tile_np_images[0].shape[-1] dtype = tile_np_images[0].dtype np_image = np.zeros(shape=(height, width, channels), dtype=dtype) + if self.blend_mode == "Linear": + merge_tiles_with_linear_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) + else: + merge_tiles_with_seam_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) - merge_tiles_with_linear_blending( - dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount - ) + # Convert into a PIL image and save pil_image = Image.fromarray(np_image) image_dto = context.services.images.create( diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3a678d825e..33c5f5c445 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,40 @@ from typing import Union import numpy as np -from invokeai.backend.tiles.utils import TBLR, Tile, paste +from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend + + +def calc_overlap(tiles: list[Tile], num_tiles_x, num_tiles_y) -> list[Tile]: + """Calculate and update the overlap of a list of tiles. + + Args: + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + num_tiles_x: the number of tiles on the x axis. + num_tiles_y: the number of tiles on the y axis. + """ + def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: + if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: + return None + return tiles[idx_y * num_tiles_x + idx_x] + + for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) + top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) + left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + + assert cur_tile is not None + + # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. + if top_neighbor_tile is not None: + cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) + top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + + # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. + if left_neighbor_tile is not None: + cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) + left_neighbor_tile.overlap.right = cur_tile.overlap.left + return tiles def calc_tiles_with_overlap( @@ -63,31 +96,117 @@ def calc_tiles_with_overlap( tiles.append(tile) - def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: - if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: - return None - return tiles[idx_y * num_tiles_x + idx_x] + return calc_overlap(tiles, num_tiles_x, num_tiles_y) - # Iterate over tiles again and calculate overlaps. + +def calc_tiles_even_split( + image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap: float = 0 +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape with the number of tiles requested. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + num_x_tiles (int): The number of tile to split the image into on the X-axis. + num_y_tiles (int): The number of tile to split the image into on the Y-axis. + overlap (int, optional): The target overlap amount of the tiles size. Defaults to 0. + + Returns: + list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom. + """ + + # Ensure tile size is divisible by 8 + if image_width % 8 != 0 or image_height % 8 != 0: + raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by 8") + + # Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up) + overlap_x = 8 * math.ceil(int((image_width / num_tiles_x) * overlap) / 8) + overlap_y = 8 * math.ceil(int((image_height / num_tiles_y) * overlap) / 8) + + # Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down) + tile_size_x = 8 * math.floor(((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / 8) + tile_size_y = 8 * math.floor(((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / 8) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) for tile_idx_y in range(num_tiles_y): + # Calculate the top and bottom of the row + top = tile_idx_y * (tile_size_y - overlap_y) + bottom = min(top + tile_size_y, image_height) + # For the last row adjust bottom to be the height of the image + if tile_idx_y == num_tiles_y - 1: + bottom = image_height + for tile_idx_x in range(num_tiles_x): - cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) - top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) - left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + # Calculate the left & right coordinate of each tile + left = tile_idx_x * (tile_size_x - overlap_x) + right = min(left + tile_size_x, image_width) + # For the last tile in the row adjust right to be the width of the image + if tile_idx_x == num_tiles_x - 1: + right = image_width - assert cur_tile is not None + tile = Tile( + coords=TBLR(top=top, bottom=bottom, left=left, right=right), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) - # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. - if top_neighbor_tile is not None: - cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) - top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + tiles.append(tile) - # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. - if left_neighbor_tile is not None: - cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) - left_neighbor_tile.overlap.right = cur_tile.overlap.left + return calc_overlap(tiles, num_tiles_x, num_tiles_y) - return tiles + +def calc_tiles_min_overlap( + image_height: int, image_width: int, tile_height: int, tile_width: int, min_overlap: int, round_to_8: bool +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + tile_height (int): The tile height in px. All tiles will have this height. + tile_width (int): The tile width in px. All tiles will have this width. + min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image + shape, then the overlap will be spread between the tiles. + + Returns: + list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom. + """ + assert image_height >= tile_height + assert image_width >= tile_width + assert min_overlap < tile_height + assert min_overlap < tile_width + + num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap)) if tile_width < image_width else 1 + num_tiles_y = ( + math.ceil((image_height - min_overlap) / (tile_height - min_overlap)) if tile_height < image_height else 1 + ) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) + for tile_idx_y in range(num_tiles_y): + top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0 + if round_to_8: + top = 8 * (top // 8) + bottom = top + tile_height + + for tile_idx_x in range(num_tiles_x): + left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0 + if round_to_8: + left = 8 * (left // 8) + right = left + tile_width + + tile = Tile( + coords=TBLR(top=top, bottom=bottom, left=left, right=right), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + + tiles.append(tile) + + return calc_overlap(tiles, num_tiles_x, num_tiles_y) def merge_tiles_with_linear_blending( @@ -199,3 +318,91 @@ def merge_tiles_with_linear_blending( ), mask=mask, ) + + +def merge_tiles_with_seam_blending( + dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int +): + """Merge a set of image tiles into `dst_image` with seam blending between the tiles. + + We expect every tile edge to either: + 1) have an overlap of 0, because it is aligned with the image edge, or + 2) have an overlap >= blend_amount. + If neither of these conditions are satisfied, we raise an exception. + + The seam blending is centered on a seam of least energy of the overlap between adjacent tiles. + + Args: + dst_image (np.ndarray): The destination image. Shape: (H, W, C). + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + tile_images (list[np.ndarray]): The tile images to merge into `dst_image`. + blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. + tiles_and_images = list(zip(tiles, tile_images, strict=True)) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top) + + # Organize tiles into rows. + tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = [] + cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = [] + first_tile_in_cur_row, _ = tiles_and_images[0] + for tile_and_image in tiles_and_images: + tile, _ = tile_and_image + if not ( + tile.coords.top == first_tile_in_cur_row.coords.top + and tile.coords.bottom == first_tile_in_cur_row.coords.bottom + ): + # Store the previous row, and start a new one. + tile_and_image_rows.append(cur_tile_and_image_row) + cur_tile_and_image_row = [] + first_tile_in_cur_row, _ = tile_and_image + + cur_tile_and_image_row.append(tile_and_image) + tile_and_image_rows.append(cur_tile_and_image_row) + + for tile_and_image_row in tile_and_image_rows: + first_tile_in_row, _ = tile_and_image_row[0] + row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top + row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype) + + # Blend the tiles in the row horizontally. + for tile, tile_image in tile_and_image_row: + # We expect the tiles to be ordered left-to-right. + # For each tile: + # - extract the overlap regions and pass to seam_blend() + # - apply blended region to the row_image + # - apply the un-blended region to the row_image + tile_height, tile_width, _ = tile_image.shape + overlap_size = tile.overlap.left + # Left blending: + if overlap_size > 0: + assert overlap_size >= blend_amount + + overlap_coord_right = tile.coords.left + overlap_size + src_overlap = row_image[:, tile.coords.left : overlap_coord_right] + dst_overlap = tile_image[:, :overlap_size] + blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False) + row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap + row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:] + else: + # no overlap just paste the tile + row_image[:, tile.coords.left : tile.coords.right] = tile_image + + # Blend the row into the dst_image + # We assume that the entire row has the same vertical overlaps as the first_tile_in_row. + # Rows are processed in the same way as tiles (extract overlap, blend, apply) + row_overlap_size = first_tile_in_row.overlap.top + if row_overlap_size > 0: + assert row_overlap_size >= blend_amount + + overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size + src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] + dst_overlap = row_image[:row_overlap_size, :] + blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True) + dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap + dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :] + else: + # no overlap just paste the row + dst_image[first_tile_in_row.coords.top:first_tile_in_row.coords.bottom, :] = row_image diff --git a/invokeai/backend/tiles/utils.py b/invokeai/backend/tiles/utils.py index 4ad40ffa35..9df63a601e 100644 --- a/invokeai/backend/tiles/utils.py +++ b/invokeai/backend/tiles/utils.py @@ -1,5 +1,7 @@ +import math from typing import Optional +import cv2 import numpy as np from pydantic import BaseModel, Field @@ -45,3 +47,101 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona mask = np.expand_dims(mask, -1) dst_image_box = dst_image[box.top : box.bottom, box.left : box.right] dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask) + + +def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool,) -> np.ndarray: + """Blend two overlapping tile sections using a seams to find a path. + + It is assumed that input images will be RGB np arrays and are the same size. + + Args: + ia1 (torch.Tensor): Image array 1 Shape: (H, W, C). + ia2 (torch.Tensor): Image array 2 Shape: (H, W, C). + x_seam (bool): If the images should be blended on the x axis or not. + blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image. + """ + assert ia1.shape == ia2.shape + assert ia2.size == ia2.size + + def shift(arr, num, fill_value=255.0): + result = np.full_like(arr, fill_value) + if num > 0: + result[num:] = arr[:-num] + elif num < 0: + result[:num] = arr[-num:] + else: + result[:] = arr + return result + + # Assume RGB and convert to grey + iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) + iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140]) + + # Calc Difference between the images + ia = iag2 - iag1 + + # If the seam is on the X-axis rotate the array so we can treat it like a vertical seam + if x_seam: + ia = np.rot90(ia, 1) + + # Calc max and min X & Y limits + # gutter is used to avoid the blur hitting the edge of the image + gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0 + max_y, max_x = ia.shape + max_x -= gutter + min_x = gutter + + # Calc the energy in the difference + energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1)) + + #Find the starting position of the seam + res = np.copy(energy) + for y in range(1, max_y): + row = res[y, :] + rowl = shift(row, -1) + rowr = shift(row, 1) + res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0) + + # create an array max_y long + lowest_energy_line = np.empty([max_y], dtype="uint16") + lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1]) + + #Calc the path of the seam + for ypos in range(max_y - 2, -1, -1): + lowest_pos = lowest_energy_line[ypos + 1] + lpos = lowest_pos - 1 + rpos = lowest_pos + 1 + lpos = np.clip(lpos, min_x, max_x - 1) + rpos = np.clip(rpos, min_x, max_x - 1) + lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos + + # Draw the mask + mask = np.zeros_like(ia) + for ypos in range(0, max_y): + to_fill = lowest_energy_line[ypos] + mask[ypos, :to_fill] = 1 + + # If the seam is on the X-axis rotate the array back + if x_seam: + mask = np.rot90(mask, 3) + + # blur the seam mask if required + if blend_amount > 0: + mask = cv2.blur(mask, (blend_amount, blend_amount)) + + # for visual debugging + #from PIL import Image + #m_image = Image.fromarray((mask * 255.0).astype("uint8")) + + # copy ia2 over ia1 while applying the seam mask + mask = np.expand_dims(mask, -1) + blended_image = ia1 * mask + ia2 * (1.0 - mask) + + # for visual debugging + #i1 = Image.fromarray(ia1.astype("uint8")) + #i2 = Image.fromarray(ia2.astype("uint8")) + #b_image = Image.fromarray(blended_image.astype("uint8")) + #print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}") + #print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}") + + return blended_image