diff --git a/invokeai/app/invocations/tiles.py b/invokeai/app/invocations/tiles.py index e59a0530ee..e368976b4b 100644 --- a/invokeai/app/invocations/tiles.py +++ b/invokeai/app/invocations/tiles.py @@ -1,3 +1,5 @@ +from typing import Literal + import numpy as np from PIL import Image from pydantic import BaseModel @@ -5,6 +7,7 @@ from pydantic import BaseModel from invokeai.app.invocations.baseinvocation import ( BaseInvocation, BaseInvocationOutput, + Input, InputField, InvocationContext, OutputField, @@ -14,7 +17,13 @@ from invokeai.app.invocations.baseinvocation import ( ) from invokeai.app.invocations.primitives import ImageField, ImageOutput from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin -from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.tiles import ( + calc_tiles_even_split, + calc_tiles_min_overlap, + calc_tiles_with_overlap, + merge_tiles_with_linear_blending, + merge_tiles_with_seam_blending, +) from invokeai.backend.tiles.utils import Tile @@ -55,6 +64,77 @@ class CalculateImageTilesInvocation(BaseInvocation): return CalculateImageTilesOutput(tiles=tiles) +@invocation( + "calculate_image_tiles_even_split", + title="Calculate Image Tiles Even Split", + tags=["tiles"], + category="tiles", + version="1.0.0", +) +class CalculateImageTilesEvenSplitInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + num_tiles_x: int = InputField( + default=2, + ge=1, + description="Number of tiles to divide image into on the x axis", + ) + num_tiles_y: int = InputField( + default=2, + ge=1, + description="Number of tiles to divide image into on the y axis", + ) + overlap_fraction: float = InputField( + default=0.25, + ge=0, + lt=1, + description="Overlap between adjacent tiles as a fraction of the tile's dimensions (0-1)", + ) + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_even_split( + image_height=self.image_height, + image_width=self.image_width, + num_tiles_x=self.num_tiles_x, + num_tiles_y=self.num_tiles_y, + overlap_fraction=self.overlap_fraction, + ) + return CalculateImageTilesOutput(tiles=tiles) + + +@invocation( + "calculate_image_tiles_min_overlap", + title="Calculate Image Tiles Minimum Overlap", + tags=["tiles"], + category="tiles", + version="1.0.0", +) +class CalculateImageTilesMinimumOverlapInvocation(BaseInvocation): + """Calculate the coordinates and overlaps of tiles that cover a target image shape.""" + + image_width: int = InputField(ge=1, default=1024, description="The image width, in pixels, to calculate tiles for.") + image_height: int = InputField( + ge=1, default=1024, description="The image height, in pixels, to calculate tiles for." + ) + tile_width: int = InputField(ge=1, default=576, description="The tile width, in pixels.") + tile_height: int = InputField(ge=1, default=576, description="The tile height, in pixels.") + min_overlap: int = InputField(default=128, ge=0, description="Minimum overlap between adjacent tiles, in pixels.") + + def invoke(self, context: InvocationContext) -> CalculateImageTilesOutput: + tiles = calc_tiles_min_overlap( + image_height=self.image_height, + image_width=self.image_width, + tile_height=self.tile_height, + tile_width=self.tile_width, + min_overlap=self.min_overlap, + ) + return CalculateImageTilesOutput(tiles=tiles) + + @invocation_output("tile_to_properties_output") class TileToPropertiesOutput(BaseInvocationOutput): coords_left: int = OutputField(description="Left coordinate of the tile relative to its parent image.") @@ -121,13 +201,22 @@ class PairTileImageInvocation(BaseInvocation): ) +BLEND_MODES = Literal["Linear", "Seam"] + + @invocation("merge_tiles_to_image", title="Merge Tiles to Image", tags=["tiles"], category="tiles", version="1.1.0") class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): """Merge multiple tile images into a single image.""" # Inputs tiles_with_images: list[TileWithImage] = InputField(description="A list of tile images with tile properties.") + blend_mode: BLEND_MODES = InputField( + default="Seam", + description="blending type Linear or Seam", + input=Input.Direct, + ) blend_amount: int = InputField( + default=32, ge=0, description="The amount to blend adjacent tiles in pixels. Must be <= the amount of overlap between adjacent tiles.", ) @@ -157,10 +246,18 @@ class MergeTilesToImageInvocation(BaseInvocation, WithMetadata): channels = tile_np_images[0].shape[-1] dtype = tile_np_images[0].dtype np_image = np.zeros(shape=(height, width, channels), dtype=dtype) + if self.blend_mode == "Linear": + merge_tiles_with_linear_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) + elif self.blend_mode == "Seam": + merge_tiles_with_seam_blending( + dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount + ) + else: + raise ValueError(f"Unsupported blend mode: '{self.blend_mode}'.") - merge_tiles_with_linear_blending( - dst_image=np_image, tiles=tiles, tile_images=tile_np_images, blend_amount=self.blend_amount - ) + # Convert into a PIL image and save pil_image = Image.fromarray(np_image) image_dto = context.services.images.create( diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 3a678d825e..1948f6624e 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -3,7 +3,42 @@ from typing import Union import numpy as np -from invokeai.backend.tiles.utils import TBLR, Tile, paste +from invokeai.app.invocations.latent import LATENT_SCALE_FACTOR +from invokeai.backend.tiles.utils import TBLR, Tile, paste, seam_blend + + +def calc_overlap(tiles: list[Tile], num_tiles_x: int, num_tiles_y: int) -> list[Tile]: + """Calculate and update the overlap of a list of tiles. + + Args: + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + num_tiles_x: the number of tiles on the x axis. + num_tiles_y: the number of tiles on the y axis. + """ + + def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: + if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: + return None + return tiles[idx_y * num_tiles_x + idx_x] + + for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) + top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) + left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + + assert cur_tile is not None + + # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. + if top_neighbor_tile is not None: + cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) + top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + + # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. + if left_neighbor_tile is not None: + cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) + left_neighbor_tile.overlap.right = cur_tile.overlap.left + return tiles def calc_tiles_with_overlap( @@ -63,31 +98,125 @@ def calc_tiles_with_overlap( tiles.append(tile) - def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: - if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: - return None - return tiles[idx_y * num_tiles_x + idx_x] + return calc_overlap(tiles, num_tiles_x, num_tiles_y) - # Iterate over tiles again and calculate overlaps. + +def calc_tiles_even_split( + image_height: int, image_width: int, num_tiles_x: int, num_tiles_y: int, overlap_fraction: float = 0 +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape with the number of tiles requested. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + num_x_tiles (int): The number of tile to split the image into on the X-axis. + num_y_tiles (int): The number of tile to split the image into on the Y-axis. + overlap_fraction (float, optional): The target overlap as fraction of the tiles size. Defaults to 0. + + Returns: + list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom. + """ + + # Ensure tile size is divisible by 8 + if image_width % LATENT_SCALE_FACTOR != 0 or image_height % LATENT_SCALE_FACTOR != 0: + raise ValueError(f"image size (({image_width}, {image_height})) must be divisible by {LATENT_SCALE_FACTOR}") + + # Calculate the overlap size based on the percentage and adjust it to be divisible by 8 (rounding up) + overlap_x = LATENT_SCALE_FACTOR * math.ceil( + int((image_width / num_tiles_x) * overlap_fraction) / LATENT_SCALE_FACTOR + ) + overlap_y = LATENT_SCALE_FACTOR * math.ceil( + int((image_height / num_tiles_y) * overlap_fraction) / LATENT_SCALE_FACTOR + ) + + # Calculate the tile size based on the number of tiles and overlap, and ensure it's divisible by 8 (rounding down) + tile_size_x = LATENT_SCALE_FACTOR * math.floor( + ((image_width + overlap_x * (num_tiles_x - 1)) // num_tiles_x) / LATENT_SCALE_FACTOR + ) + tile_size_y = LATENT_SCALE_FACTOR * math.floor( + ((image_height + overlap_y * (num_tiles_y - 1)) // num_tiles_y) / LATENT_SCALE_FACTOR + ) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) for tile_idx_y in range(num_tiles_y): + # Calculate the top and bottom of the row + top = tile_idx_y * (tile_size_y - overlap_y) + bottom = min(top + tile_size_y, image_height) + # For the last row adjust bottom to be the height of the image + if tile_idx_y == num_tiles_y - 1: + bottom = image_height + for tile_idx_x in range(num_tiles_x): - cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) - top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) - left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + # Calculate the left & right coordinate of each tile + left = tile_idx_x * (tile_size_x - overlap_x) + right = min(left + tile_size_x, image_width) + # For the last tile in the row adjust right to be the width of the image + if tile_idx_x == num_tiles_x - 1: + right = image_width - assert cur_tile is not None + tile = Tile( + coords=TBLR(top=top, bottom=bottom, left=left, right=right), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) - # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. - if top_neighbor_tile is not None: - cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) - top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + tiles.append(tile) - # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. - if left_neighbor_tile is not None: - cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) - left_neighbor_tile.overlap.right = cur_tile.overlap.left + return calc_overlap(tiles, num_tiles_x, num_tiles_y) - return tiles + +def calc_tiles_min_overlap( + image_height: int, + image_width: int, + tile_height: int, + tile_width: int, + min_overlap: int = 0, +) -> list[Tile]: + """Calculate the tile coordinates for a given image shape under a simple tiling scheme with overlaps. + + Args: + image_height (int): The image height in px. + image_width (int): The image width in px. + tile_height (int): The tile height in px. All tiles will have this height. + tile_width (int): The tile width in px. All tiles will have this width. + min_overlap (int): The target minimum overlap between adjacent tiles. If the tiles do not evenly cover the image + shape, then the overlap will be spread between the tiles. + + Returns: + list[Tile]: A list of tiles that cover the image shape. Ordered from left-to-right, top-to-bottom. + """ + + assert min_overlap < tile_height + assert min_overlap < tile_width + + # The If Else catches the case when the tile size is larger than the images size and just clips the number of tiles to 1 + num_tiles_x = math.ceil((image_width - min_overlap) / (tile_width - min_overlap)) if tile_width < image_width else 1 + num_tiles_y = ( + math.ceil((image_height - min_overlap) / (tile_height - min_overlap)) if tile_height < image_height else 1 + ) + + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. + tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) + for tile_idx_y in range(num_tiles_y): + top = (tile_idx_y * (image_height - tile_height)) // (num_tiles_y - 1) if num_tiles_y > 1 else 0 + bottom = top + tile_height + + for tile_idx_x in range(num_tiles_x): + left = (tile_idx_x * (image_width - tile_width)) // (num_tiles_x - 1) if num_tiles_x > 1 else 0 + right = left + tile_width + + tile = Tile( + coords=TBLR(top=top, bottom=bottom, left=left, right=right), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + + tiles.append(tile) + + return calc_overlap(tiles, num_tiles_x, num_tiles_y) def merge_tiles_with_linear_blending( @@ -199,3 +328,91 @@ def merge_tiles_with_linear_blending( ), mask=mask, ) + + +def merge_tiles_with_seam_blending( + dst_image: np.ndarray, tiles: list[Tile], tile_images: list[np.ndarray], blend_amount: int +): + """Merge a set of image tiles into `dst_image` with seam blending between the tiles. + + We expect every tile edge to either: + 1) have an overlap of 0, because it is aligned with the image edge, or + 2) have an overlap >= blend_amount. + If neither of these conditions are satisfied, we raise an exception. + + The seam blending is centered on a seam of least energy of the overlap between adjacent tiles. + + Args: + dst_image (np.ndarray): The destination image. Shape: (H, W, C). + tiles (list[Tile]): The list of tiles describing the locations of the respective `tile_images`. + tile_images (list[np.ndarray]): The tile images to merge into `dst_image`. + blend_amount (int): The amount of blending (in px) between adjacent overlapping tiles. + """ + # Sort tiles and images first by left x coordinate, then by top y coordinate. During tile processing, we want to + # iterate over tiles left-to-right, top-to-bottom. + tiles_and_images = list(zip(tiles, tile_images, strict=True)) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.left) + tiles_and_images = sorted(tiles_and_images, key=lambda x: x[0].coords.top) + + # Organize tiles into rows. + tile_and_image_rows: list[list[tuple[Tile, np.ndarray]]] = [] + cur_tile_and_image_row: list[tuple[Tile, np.ndarray]] = [] + first_tile_in_cur_row, _ = tiles_and_images[0] + for tile_and_image in tiles_and_images: + tile, _ = tile_and_image + if not ( + tile.coords.top == first_tile_in_cur_row.coords.top + and tile.coords.bottom == first_tile_in_cur_row.coords.bottom + ): + # Store the previous row, and start a new one. + tile_and_image_rows.append(cur_tile_and_image_row) + cur_tile_and_image_row = [] + first_tile_in_cur_row, _ = tile_and_image + + cur_tile_and_image_row.append(tile_and_image) + tile_and_image_rows.append(cur_tile_and_image_row) + + for tile_and_image_row in tile_and_image_rows: + first_tile_in_row, _ = tile_and_image_row[0] + row_height = first_tile_in_row.coords.bottom - first_tile_in_row.coords.top + row_image = np.zeros((row_height, dst_image.shape[1], dst_image.shape[2]), dtype=dst_image.dtype) + + # Blend the tiles in the row horizontally. + for tile, tile_image in tile_and_image_row: + # We expect the tiles to be ordered left-to-right. + # For each tile: + # - extract the overlap regions and pass to seam_blend() + # - apply blended region to the row_image + # - apply the un-blended region to the row_image + tile_height, tile_width, _ = tile_image.shape + overlap_size = tile.overlap.left + # Left blending: + if overlap_size > 0: + assert overlap_size >= blend_amount + + overlap_coord_right = tile.coords.left + overlap_size + src_overlap = row_image[:, tile.coords.left : overlap_coord_right] + dst_overlap = tile_image[:, :overlap_size] + blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=False) + row_image[:, tile.coords.left : overlap_coord_right] = blended_overlap + row_image[:, overlap_coord_right : tile.coords.right] = tile_image[:, overlap_size:] + else: + # no overlap just paste the tile + row_image[:, tile.coords.left : tile.coords.right] = tile_image + + # Blend the row into the dst_image + # We assume that the entire row has the same vertical overlaps as the first_tile_in_row. + # Rows are processed in the same way as tiles (extract overlap, blend, apply) + row_overlap_size = first_tile_in_row.overlap.top + if row_overlap_size > 0: + assert row_overlap_size >= blend_amount + + overlap_coords_bottom = first_tile_in_row.coords.top + row_overlap_size + src_overlap = dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] + dst_overlap = row_image[:row_overlap_size, :] + blended_overlap = seam_blend(src_overlap, dst_overlap, blend_amount, x_seam=True) + dst_image[first_tile_in_row.coords.top : overlap_coords_bottom, :] = blended_overlap + dst_image[overlap_coords_bottom : first_tile_in_row.coords.bottom, :] = row_image[row_overlap_size:, :] + else: + # no overlap just paste the row + dst_image[first_tile_in_row.coords.top : first_tile_in_row.coords.bottom, :] = row_image diff --git a/invokeai/backend/tiles/utils.py b/invokeai/backend/tiles/utils.py index 4ad40ffa35..dc6d914170 100644 --- a/invokeai/backend/tiles/utils.py +++ b/invokeai/backend/tiles/utils.py @@ -1,5 +1,7 @@ +import math from typing import Optional +import cv2 import numpy as np from pydantic import BaseModel, Field @@ -31,10 +33,10 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona """Paste a source image into a destination image. Args: - dst_image (torch.Tensor): The destination image to paste into. Shape: (H, W, C). - src_image (torch.Tensor): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'. + dst_image (np.array): The destination image to paste into. Shape: (H, W, C). + src_image (np.array): The source image to paste. Shape: (H, W, C). H and W must be compatible with 'box'. box (TBLR): Box defining the region in the 'dst_image' where 'src_image' will be pasted. - mask (Optional[torch.Tensor]): A mask that defines the blending between 'src_image' and 'dst_image'. + mask (Optional[np.array]): A mask that defines the blending between 'src_image' and 'dst_image'. Range: [0.0, 1.0], Shape: (H, W). The output is calculate per-pixel according to `src * mask + dst * (1 - mask)`. """ @@ -45,3 +47,106 @@ def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optiona mask = np.expand_dims(mask, -1) dst_image_box = dst_image[box.top : box.bottom, box.left : box.right] dst_image[box.top : box.bottom, box.left : box.right] = src_image * mask + dst_image_box * (1.0 - mask) + + +def seam_blend(ia1: np.ndarray, ia2: np.ndarray, blend_amount: int, x_seam: bool) -> np.ndarray: + """Blend two overlapping tile sections using a seams to find a path. + + It is assumed that input images will be RGB np arrays and are the same size. + + Args: + ia1 (np.array): Image array 1 Shape: (H, W, C). + ia2 (np.array): Image array 2 Shape: (H, W, C). + x_seam (bool): If the images should be blended on the x axis or not. + blend_amount (int): The size of the blur to use on the seam. Half of this value will be used to avoid the edges of the image. + """ + assert ia1.shape == ia2.shape + assert ia2.size == ia2.size + + def shift(arr, num, fill_value=255.0): + result = np.full_like(arr, fill_value) + if num > 0: + result[num:] = arr[:-num] + elif num < 0: + result[:num] = arr[-num:] + else: + result[:] = arr + return result + + # Assume RGB and convert to grey + # Could offer other options for the luminance conversion + # BT.709 [0.2126, 0.7152, 0.0722], BT.2020 [0.2627, 0.6780, 0.0593]) + # it might not have a huge impact due to the blur that is applied over the seam + iag1 = np.dot(ia1, [0.2989, 0.5870, 0.1140]) # BT.601 perceived brightness + iag2 = np.dot(ia2, [0.2989, 0.5870, 0.1140]) + + # Calc Difference between the images + ia = iag2 - iag1 + + # If the seam is on the X-axis rotate the array so we can treat it like a vertical seam + if x_seam: + ia = np.rot90(ia, 1) + + # Calc max and min X & Y limits + # gutter is used to avoid the blur hitting the edge of the image + gutter = math.ceil(blend_amount / 2) if blend_amount > 0 else 0 + max_y, max_x = ia.shape + max_x -= gutter + min_x = gutter + + # Calc the energy in the difference + # Could offer different energy calculations e.g. Sobel or Scharr + energy = np.abs(np.gradient(ia, axis=0)) + np.abs(np.gradient(ia, axis=1)) + + # Find the starting position of the seam + res = np.copy(energy) + for y in range(1, max_y): + row = res[y, :] + rowl = shift(row, -1) + rowr = shift(row, 1) + res[y, :] = res[y - 1, :] + np.min([row, rowl, rowr], axis=0) + + # create an array max_y long + lowest_energy_line = np.empty([max_y], dtype="uint16") + lowest_energy_line[max_y - 1] = np.argmin(res[max_y - 1, min_x : max_x - 1]) + + # Calc the path of the seam + # could offer options for larger search than just 1 pixel by adjusting lpos and rpos + for ypos in range(max_y - 2, -1, -1): + lowest_pos = lowest_energy_line[ypos + 1] + lpos = lowest_pos - 1 + rpos = lowest_pos + 1 + lpos = np.clip(lpos, min_x, max_x - 1) + rpos = np.clip(rpos, min_x, max_x - 1) + lowest_energy_line[ypos] = np.argmin(energy[ypos, lpos : rpos + 1]) + lpos + + # Draw the mask + mask = np.zeros_like(ia) + for ypos in range(0, max_y): + to_fill = lowest_energy_line[ypos] + mask[ypos, :to_fill] = 1 + + # If the seam is on the X-axis rotate the array back + if x_seam: + mask = np.rot90(mask, 3) + + # blur the seam mask if required + if blend_amount > 0: + mask = cv2.blur(mask, (blend_amount, blend_amount)) + + # for visual debugging + # from PIL import Image + # m_image = Image.fromarray((mask * 255.0).astype("uint8")) + + # copy ia2 over ia1 while applying the seam mask + mask = np.expand_dims(mask, -1) + blended_image = ia1 * mask + ia2 * (1.0 - mask) + + # for visual debugging + # i1 = Image.fromarray(ia1.astype("uint8")) + # i2 = Image.fromarray(ia2.astype("uint8")) + # b_image = Image.fromarray(blended_image.astype("uint8")) + # print(f"{ia1.shape}, {ia2.shape}, {mask.shape}, {blended_image.shape}") + # print(f"{i1.size}, {i2.size}, {m_image.size}, {b_image.size}") + + return blended_image diff --git a/tests/backend/tiles/test_tiles.py b/tests/backend/tiles/test_tiles.py index 353e65d336..0b18f9ed54 100644 --- a/tests/backend/tiles/test_tiles.py +++ b/tests/backend/tiles/test_tiles.py @@ -1,7 +1,12 @@ import numpy as np import pytest -from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending +from invokeai.backend.tiles.tiles import ( + calc_tiles_even_split, + calc_tiles_min_overlap, + calc_tiles_with_overlap, + merge_tiles_with_linear_blending, +) from invokeai.backend.tiles.utils import TBLR, Tile #################################### @@ -14,7 +19,10 @@ def test_calc_tiles_with_overlap_single_tile(): tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64) expected_tiles = [ - Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0)) + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=1024), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) ] assert tiles == expected_tiles @@ -27,13 +35,31 @@ def test_calc_tiles_with_overlap_evenly_divisible(): expected_tiles = [ # Row 0 - Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)), - Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)), - Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)), + Tile( + coords=TBLR(top=0, bottom=320, left=0, right=576), + overlap=TBLR(top=0, bottom=64, left=0, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=512, right=1088), + overlap=TBLR(top=0, bottom=64, left=64, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=1024, right=1600), + overlap=TBLR(top=0, bottom=64, left=64, right=0), + ), # Row 1 - Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)), - Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)), - Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)), + Tile( + coords=TBLR(top=256, bottom=576, left=0, right=576), + overlap=TBLR(top=64, bottom=0, left=0, right=64), + ), + Tile( + coords=TBLR(top=256, bottom=576, left=512, right=1088), + overlap=TBLR(top=64, bottom=0, left=64, right=64), + ), + Tile( + coords=TBLR(top=256, bottom=576, left=1024, right=1600), + overlap=TBLR(top=64, bottom=0, left=64, right=0), + ), ] assert tiles == expected_tiles @@ -46,16 +72,30 @@ def test_calc_tiles_with_overlap_not_evenly_divisible(): expected_tiles = [ # Row 0 - Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)), - Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)), - Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)), - # Row 1 - Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)), Tile( - coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272) + coords=TBLR(top=0, bottom=256, left=0, right=512), + overlap=TBLR(top=0, bottom=112, left=0, right=64), ), Tile( - coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0) + coords=TBLR(top=0, bottom=256, left=448, right=960), + overlap=TBLR(top=0, bottom=112, left=64, right=272), + ), + Tile( + coords=TBLR(top=0, bottom=256, left=688, right=1200), + overlap=TBLR(top=0, bottom=112, left=272, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=144, bottom=400, left=0, right=512), + overlap=TBLR(top=112, bottom=0, left=0, right=64), + ), + Tile( + coords=TBLR(top=144, bottom=400, left=448, right=960), + overlap=TBLR(top=112, bottom=0, left=64, right=272), + ), + Tile( + coords=TBLR(top=144, bottom=400, left=688, right=1200), + overlap=TBLR(top=112, bottom=0, left=272, right=0), ), ] @@ -75,7 +115,12 @@ def test_calc_tiles_with_overlap_not_evenly_divisible(): ], ) def test_calc_tiles_with_overlap_input_validation( - image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool + image_height: int, + image_width: int, + tile_height: int, + tile_width: int, + overlap: int, + raises: bool, ): """Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid.""" if raises: @@ -85,6 +130,306 @@ def test_calc_tiles_with_overlap_input_validation( calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap) +#################################### +# Test calc_tiles_min_overlap(...) +#################################### + + +def test_calc_tiles_min_overlap_single_tile(): + """Test calc_tiles_min_overlap() behavior when a single tile covers the image.""" + tiles = calc_tiles_min_overlap( + image_height=512, + image_width=1024, + tile_height=512, + tile_width=1024, + min_overlap=64, + ) + + expected_tiles = [ + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=1024), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_min_overlap_evenly_divisible(): + """Test calc_tiles_min_overlap() behavior when the image is evenly covered by multiple tiles.""" + # Parameters mimic roughly the same output as the original tile generations of the same test name + tiles = calc_tiles_min_overlap( + image_height=576, + image_width=1600, + tile_height=320, + tile_width=576, + min_overlap=64, + ) + + expected_tiles = [ + # Row 0 + Tile( + coords=TBLR(top=0, bottom=320, left=0, right=576), + overlap=TBLR(top=0, bottom=64, left=0, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=512, right=1088), + overlap=TBLR(top=0, bottom=64, left=64, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=1024, right=1600), + overlap=TBLR(top=0, bottom=64, left=64, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=256, bottom=576, left=0, right=576), + overlap=TBLR(top=64, bottom=0, left=0, right=64), + ), + Tile( + coords=TBLR(top=256, bottom=576, left=512, right=1088), + overlap=TBLR(top=64, bottom=0, left=64, right=64), + ), + Tile( + coords=TBLR(top=256, bottom=576, left=1024, right=1600), + overlap=TBLR(top=64, bottom=0, left=64, right=0), + ), + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_min_overlap_not_evenly_divisible(): + """Test calc_tiles_min_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage.""" + # Parameters mimic roughly the same output as the original tile generations of the same test name + tiles = calc_tiles_min_overlap( + image_height=400, + image_width=1200, + tile_height=256, + tile_width=512, + min_overlap=64, + ) + + expected_tiles = [ + # Row 0 + Tile( + coords=TBLR(top=0, bottom=256, left=0, right=512), + overlap=TBLR(top=0, bottom=112, left=0, right=168), + ), + Tile( + coords=TBLR(top=0, bottom=256, left=344, right=856), + overlap=TBLR(top=0, bottom=112, left=168, right=168), + ), + Tile( + coords=TBLR(top=0, bottom=256, left=688, right=1200), + overlap=TBLR(top=0, bottom=112, left=168, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=144, bottom=400, left=0, right=512), + overlap=TBLR(top=112, bottom=0, left=0, right=168), + ), + Tile( + coords=TBLR(top=144, bottom=400, left=344, right=856), + overlap=TBLR(top=112, bottom=0, left=168, right=168), + ), + Tile( + coords=TBLR(top=144, bottom=400, left=688, right=1200), + overlap=TBLR(top=112, bottom=0, left=168, right=0), + ), + ] + + assert tiles == expected_tiles + + +@pytest.mark.parametrize( + [ + "image_height", + "image_width", + "tile_height", + "tile_width", + "min_overlap", + "raises", + ], + [ + (128, 128, 128, 128, 127, False), # OK + (128, 128, 128, 128, 0, False), # OK + (128, 128, 64, 64, 0, False), # OK + (128, 128, 129, 128, 0, False), # tile_height exceeds image_height defaults to 1 tile. + (128, 128, 128, 129, 0, False), # tile_width exceeds image_width defaults to 1 tile. + (128, 128, 64, 128, 64, True), # overlap equals tile_height. + (128, 128, 128, 64, 64, True), # overlap equals tile_width. + ], +) +def test_calc_tiles_min_overlap_input_validation( + image_height: int, + image_width: int, + tile_height: int, + tile_width: int, + min_overlap: int, + raises: bool, +): + """Test that calc_tiles_min_overlap() raises an exception if the inputs are invalid.""" + if raises: + with pytest.raises(AssertionError): + calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap) + else: + calc_tiles_min_overlap(image_height, image_width, tile_height, tile_width, min_overlap) + + +#################################### +# Test calc_tiles_even_split(...) +#################################### + + +def test_calc_tiles_even_split_single_tile(): + """Test calc_tiles_even_split() behavior when a single tile covers the image.""" + tiles = calc_tiles_even_split( + image_height=512, image_width=1024, num_tiles_x=1, num_tiles_y=1, overlap_fraction=0.25 + ) + + expected_tiles = [ + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=1024), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_even_split_evenly_divisible(): + """Test calc_tiles_even_split() behavior when the image is evenly covered by multiple tiles.""" + # Parameters mimic roughly the same output as the original tile generations of the same test name + tiles = calc_tiles_even_split( + image_height=576, image_width=1600, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25 + ) + + expected_tiles = [ + # Row 0 + Tile( + coords=TBLR(top=0, bottom=320, left=0, right=624), + overlap=TBLR(top=0, bottom=72, left=0, right=136), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=488, right=1112), + overlap=TBLR(top=0, bottom=72, left=136, right=136), + ), + Tile( + coords=TBLR(top=0, bottom=320, left=976, right=1600), + overlap=TBLR(top=0, bottom=72, left=136, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=248, bottom=576, left=0, right=624), + overlap=TBLR(top=72, bottom=0, left=0, right=136), + ), + Tile( + coords=TBLR(top=248, bottom=576, left=488, right=1112), + overlap=TBLR(top=72, bottom=0, left=136, right=136), + ), + Tile( + coords=TBLR(top=248, bottom=576, left=976, right=1600), + overlap=TBLR(top=72, bottom=0, left=136, right=0), + ), + ] + assert tiles == expected_tiles + + +def test_calc_tiles_even_split_not_evenly_divisible(): + """Test calc_tiles_even_split() behavior when the image requires 'uneven' overlaps to achieve proper coverage.""" + # Parameters mimic roughly the same output as the original tile generations of the same test name + tiles = calc_tiles_even_split( + image_height=400, image_width=1200, num_tiles_x=3, num_tiles_y=2, overlap_fraction=0.25 + ) + + expected_tiles = [ + # Row 0 + Tile( + coords=TBLR(top=0, bottom=224, left=0, right=464), + overlap=TBLR(top=0, bottom=56, left=0, right=104), + ), + Tile( + coords=TBLR(top=0, bottom=224, left=360, right=824), + overlap=TBLR(top=0, bottom=56, left=104, right=104), + ), + Tile( + coords=TBLR(top=0, bottom=224, left=720, right=1200), + overlap=TBLR(top=0, bottom=56, left=104, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=168, bottom=400, left=0, right=464), + overlap=TBLR(top=56, bottom=0, left=0, right=104), + ), + Tile( + coords=TBLR(top=168, bottom=400, left=360, right=824), + overlap=TBLR(top=56, bottom=0, left=104, right=104), + ), + Tile( + coords=TBLR(top=168, bottom=400, left=720, right=1200), + overlap=TBLR(top=56, bottom=0, left=104, right=0), + ), + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_even_split_difficult_size(): + """Test calc_tiles_even_split() behavior when the image is a difficult size to spilt evenly and keep div8.""" + # Parameters are a difficult size for other tile gen routines to calculate + tiles = calc_tiles_even_split( + image_height=1000, image_width=1000, num_tiles_x=2, num_tiles_y=2, overlap_fraction=0.25 + ) + + expected_tiles = [ + # Row 0 + Tile( + coords=TBLR(top=0, bottom=560, left=0, right=560), + overlap=TBLR(top=0, bottom=128, left=0, right=128), + ), + Tile( + coords=TBLR(top=0, bottom=560, left=432, right=1000), + overlap=TBLR(top=0, bottom=128, left=128, right=0), + ), + # Row 1 + Tile( + coords=TBLR(top=432, bottom=1000, left=0, right=560), + overlap=TBLR(top=128, bottom=0, left=0, right=128), + ), + Tile( + coords=TBLR(top=432, bottom=1000, left=432, right=1000), + overlap=TBLR(top=128, bottom=0, left=128, right=0), + ), + ] + + assert tiles == expected_tiles + + +@pytest.mark.parametrize( + ["image_height", "image_width", "num_tiles_x", "num_tiles_y", "overlap_fraction", "raises"], + [ + (128, 128, 1, 1, 0.25, False), # OK + (128, 128, 1, 1, 0, False), # OK + (128, 128, 2, 1, 0, False), # OK + (127, 127, 1, 1, 0, True), # image size must be dividable by 8 + ], +) +def test_calc_tiles_even_split_input_validation( + image_height: int, + image_width: int, + num_tiles_x: int, + num_tiles_y: int, + overlap_fraction: float, + raises: bool, +): + """Test that calc_tiles_even_split() raises an exception if the inputs are invalid.""" + if raises: + with pytest.raises(ValueError): + calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction) + else: + calc_tiles_even_split(image_height, image_width, num_tiles_x, num_tiles_y, overlap_fraction) + + ############################################# # Test merge_tiles_with_linear_blending(...) ############################################# @@ -95,8 +440,14 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int): """Test merge_tiles_with_linear_blending(...) behavior when merging horizontally.""" # Initialize 2 tiles side-by-side. tiles = [ - Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)), - Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)), + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=0, left=0, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=512, left=448, right=960), + overlap=TBLR(top=0, bottom=0, left=64, right=0), + ), ] dst_image = np.zeros((512, 960, 3), dtype=np.uint8) @@ -116,7 +467,10 @@ def test_merge_tiles_with_linear_blending_horizontal(blend_amount: int): expected_output[:, 480 + (blend_amount // 2) :, :] = 128 merge_tiles_with_linear_blending( - dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount + dst_image=dst_image, + tiles=tiles, + tile_images=tile_images, + blend_amount=blend_amount, ) np.testing.assert_array_equal(dst_image, expected_output, strict=True) @@ -127,8 +481,14 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int): """Test merge_tiles_with_linear_blending(...) behavior when merging vertically.""" # Initialize 2 tiles stacked vertically. tiles = [ - Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)), - Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)), + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=64, left=0, right=0), + ), + Tile( + coords=TBLR(top=448, bottom=960, left=0, right=512), + overlap=TBLR(top=64, bottom=0, left=0, right=0), + ), ] dst_image = np.zeros((960, 512, 3), dtype=np.uint8) @@ -148,7 +508,10 @@ def test_merge_tiles_with_linear_blending_vertical(blend_amount: int): expected_output[480 + (blend_amount // 2) :, :, :] = 128 merge_tiles_with_linear_blending( - dst_image=dst_image, tiles=tiles, tile_images=tile_images, blend_amount=blend_amount + dst_image=dst_image, + tiles=tiles, + tile_images=tile_images, + blend_amount=blend_amount, ) np.testing.assert_array_equal(dst_image, expected_output, strict=True) @@ -160,8 +523,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_vertical_overlap( """ # Initialize 2 tiles stacked vertically. tiles = [ - Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=64, left=0, right=0)), - Tile(coords=TBLR(top=448, bottom=960, left=0, right=512), overlap=TBLR(top=64, bottom=0, left=0, right=0)), + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=64, left=0, right=0), + ), + Tile( + coords=TBLR(top=448, bottom=960, left=0, right=512), + overlap=TBLR(top=64, bottom=0, left=0, right=0), + ), ] dst_image = np.zeros((960, 512, 3), dtype=np.uint8) @@ -180,8 +549,14 @@ def test_merge_tiles_with_linear_blending_blend_amount_exceeds_horizontal_overla """ # Initialize 2 tiles side-by-side. tiles = [ - Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=64)), - Tile(coords=TBLR(top=0, bottom=512, left=448, right=960), overlap=TBLR(top=0, bottom=0, left=64, right=0)), + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=0, left=0, right=64), + ), + Tile( + coords=TBLR(top=0, bottom=512, left=448, right=960), + overlap=TBLR(top=0, bottom=0, left=64, right=0), + ), ] dst_image = np.zeros((512, 960, 3), dtype=np.uint8) @@ -198,7 +573,12 @@ def test_merge_tiles_with_linear_blending_tiles_overflow_dst_image(): """Test that merge_tiles_with_linear_blending(...) raises an exception if any of the tiles overflows the dst_image. """ - tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))] + tiles = [ + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + ] dst_image = np.zeros((256, 512, 3), dtype=np.uint8) @@ -213,7 +593,12 @@ def test_merge_tiles_with_linear_blending_mismatched_list_lengths(): """Test that merge_tiles_with_linear_blending(...) raises an exception if the lengths of 'tiles' and 'tile_images' do not match. """ - tiles = [Tile(coords=TBLR(top=0, bottom=512, left=0, right=512), overlap=TBLR(top=0, bottom=0, left=0, right=0))] + tiles = [ + Tile( + coords=TBLR(top=0, bottom=512, left=0, right=512), + overlap=TBLR(top=0, bottom=0, left=0, right=0), + ) + ] dst_image = np.zeros((256, 512, 3), dtype=np.uint8)