diff --git a/invokeai/backend/tiles/tiles.py b/invokeai/backend/tiles/tiles.py index 566381d1ff..5e5c4b7050 100644 --- a/invokeai/backend/tiles/tiles.py +++ b/invokeai/backend/tiles/tiles.py @@ -1,17 +1,10 @@ import math +from typing import Union import numpy as np from invokeai.backend.tiles.utils import TBLR, Tile, paste -# TODO(ryand) -# Test the following: -# - Tile too big in x, y -# - Overlap too big in x, y -# - Single tile fits -# - Multiple tiles fit perfectly -# - Not evenly divisible by tile size(with overlap) - def calc_tiles_with_overlap( image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0 @@ -40,8 +33,10 @@ def calc_tiles_with_overlap( num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height) num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width) - # Calculate tile coordinates and overlaps. + # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column. tiles: list[Tile] = [] + + # Calculate tile coordinates. (Ignore overlap values for now.) for tile_idx_y in range(num_tiles_y): for tile_idx_x in range(num_tiles_x): tile = Tile( @@ -51,12 +46,7 @@ def calc_tiles_with_overlap( left=tile_idx_x * non_overlap_per_tile_width, right=tile_idx_x * non_overlap_per_tile_width + tile_width, ), - overlap=TBLR( - top=0 if tile_idx_y == 0 else overlap, - bottom=overlap, - left=0 if tile_idx_x == 0 else overlap, - right=overlap, - ), + overlap=TBLR(top=0, bottom=0, left=0, right=0), ) if tile.coords.bottom > image_height: @@ -64,23 +54,39 @@ def calc_tiles_with_overlap( # of the image. tile.coords.bottom = image_height tile.coords.top = image_height - tile_height - tile.overlap.bottom = 0 - # Note that this could result in a large overlap between this tile and the one above it. - top_neighbor_bottom = (tile_idx_y - 1) * non_overlap_per_tile_height + tile_height - tile.overlap.top = top_neighbor_bottom - tile.coords.top if tile.coords.right > image_width: # If this tile would go off the right edge of the image, shift it so that it is aligned with the # right edge of the image. tile.coords.right = image_width tile.coords.left = image_width - tile_width - tile.overlap.right = 0 - # Note that this could result in a large overlap between this tile and the one to its left. - left_neighbor_right = (tile_idx_x - 1) * non_overlap_per_tile_width + tile_width - tile.overlap.left = left_neighbor_right - tile.coords.left tiles.append(tile) + def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]: + if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x: + return None + return tiles[idx_y * num_tiles_x + idx_x] + + # Iterate over tiles again and calculate overlaps. + for tile_idx_y in range(num_tiles_y): + for tile_idx_x in range(num_tiles_x): + cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x) + top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x) + left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1) + + assert cur_tile is not None + + # Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap. + if top_neighbor_tile is not None: + cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top) + top_neighbor_tile.overlap.bottom = cur_tile.overlap.top + + # Update cur_tile left-overlap and corresponding left-neighbor right-overlap. + if left_neighbor_tile is not None: + cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left) + left_neighbor_tile.overlap.right = cur_tile.overlap.left + return tiles diff --git a/invokeai/backend/tiles/utils.py b/invokeai/backend/tiles/utils.py index cf8e926aa5..4ad40ffa35 100644 --- a/invokeai/backend/tiles/utils.py +++ b/invokeai/backend/tiles/utils.py @@ -10,11 +10,22 @@ class TBLR(BaseModel): left: int right: int + def __eq__(self, other): + return ( + self.top == other.top + and self.bottom == other.bottom + and self.left == other.left + and self.right == other.right + ) + class Tile(BaseModel): coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.") overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.") + def __eq__(self, other): + return self.coords == other.coords and self.overlap == other.overlap + def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None): """Paste a source image into a destination image. diff --git a/tests/backend/tiles/test_tiles.py b/tests/backend/tiles/test_tiles.py new file mode 100644 index 0000000000..332ab15005 --- /dev/null +++ b/tests/backend/tiles/test_tiles.py @@ -0,0 +1,84 @@ +import pytest + +from invokeai.backend.tiles.tiles import calc_tiles_with_overlap +from invokeai.backend.tiles.utils import TBLR, Tile + +#################################### +# Test calc_tiles_with_overlap(...) +#################################### + + +def test_calc_tiles_with_overlap_single_tile(): + """Test calc_tiles_with_overlap() behavior when a single tile covers the image.""" + tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64) + + expected_tiles = [ + Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0)) + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_with_overlap_evenly_divisible(): + """Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles.""" + # Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles. + tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64) + + expected_tiles = [ + # Row 0 + Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)), + Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)), + Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)), + # Row 1 + Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)), + Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)), + Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)), + ] + + assert tiles == expected_tiles + + +def test_calc_tiles_with_overlap_not_evenly_divisible(): + """Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage.""" + # Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps. + tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64) + + expected_tiles = [ + # Row 0 + Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)), + Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)), + Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)), + # Row 1 + Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)), + Tile( + coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272) + ), + Tile( + coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0) + ), + ] + + assert tiles == expected_tiles + + +@pytest.mark.parametrize( + ["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"], + [ + (128, 128, 128, 128, 127, False), # OK + (128, 128, 128, 128, 0, False), # OK + (128, 128, 64, 64, 0, False), # OK + (128, 128, 129, 128, 0, True), # tile_height exceeds image_height. + (128, 128, 128, 129, 0, True), # tile_width exceeds image_width. + (128, 128, 64, 128, 64, True), # overlap equals tile_height. + (128, 128, 128, 64, 64, True), # overlap equals tile_width. + ], +) +def test_calc_tiles_with_overlap_input_validation( + image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool +): + """Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid.""" + if raises: + with pytest.raises(AssertionError): + calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap) + else: + calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)