Add unit tests for calc_tiles_with_overlap(...) and fix a bug in its implementation.

This commit is contained in:
Ryan Dick 2023-11-20 14:23:49 -05:00
parent caf47dee09
commit 1f63fa8236
3 changed files with 124 additions and 23 deletions

View File

@ -1,17 +1,10 @@
import math import math
from typing import Union
import numpy as np import numpy as np
from invokeai.backend.tiles.utils import TBLR, Tile, paste from invokeai.backend.tiles.utils import TBLR, Tile, paste
# TODO(ryand)
# Test the following:
# - Tile too big in x, y
# - Overlap too big in x, y
# - Single tile fits
# - Multiple tiles fit perfectly
# - Not evenly divisible by tile size(with overlap)
def calc_tiles_with_overlap( def calc_tiles_with_overlap(
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0 image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
@ -40,8 +33,10 @@ def calc_tiles_with_overlap(
num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height) num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width) num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)
# Calculate tile coordinates and overlaps. # tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
tiles: list[Tile] = [] tiles: list[Tile] = []
# Calculate tile coordinates. (Ignore overlap values for now.)
for tile_idx_y in range(num_tiles_y): for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x): for tile_idx_x in range(num_tiles_x):
tile = Tile( tile = Tile(
@ -51,12 +46,7 @@ def calc_tiles_with_overlap(
left=tile_idx_x * non_overlap_per_tile_width, left=tile_idx_x * non_overlap_per_tile_width,
right=tile_idx_x * non_overlap_per_tile_width + tile_width, right=tile_idx_x * non_overlap_per_tile_width + tile_width,
), ),
overlap=TBLR( overlap=TBLR(top=0, bottom=0, left=0, right=0),
top=0 if tile_idx_y == 0 else overlap,
bottom=overlap,
left=0 if tile_idx_x == 0 else overlap,
right=overlap,
),
) )
if tile.coords.bottom > image_height: if tile.coords.bottom > image_height:
@ -64,23 +54,39 @@ def calc_tiles_with_overlap(
# of the image. # of the image.
tile.coords.bottom = image_height tile.coords.bottom = image_height
tile.coords.top = image_height - tile_height tile.coords.top = image_height - tile_height
tile.overlap.bottom = 0
# Note that this could result in a large overlap between this tile and the one above it.
top_neighbor_bottom = (tile_idx_y - 1) * non_overlap_per_tile_height + tile_height
tile.overlap.top = top_neighbor_bottom - tile.coords.top
if tile.coords.right > image_width: if tile.coords.right > image_width:
# If this tile would go off the right edge of the image, shift it so that it is aligned with the # If this tile would go off the right edge of the image, shift it so that it is aligned with the
# right edge of the image. # right edge of the image.
tile.coords.right = image_width tile.coords.right = image_width
tile.coords.left = image_width - tile_width tile.coords.left = image_width - tile_width
tile.overlap.right = 0
# Note that this could result in a large overlap between this tile and the one to its left.
left_neighbor_right = (tile_idx_x - 1) * non_overlap_per_tile_width + tile_width
tile.overlap.left = left_neighbor_right - tile.coords.left
tiles.append(tile) tiles.append(tile)
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
return None
return tiles[idx_y * num_tiles_x + idx_x]
# Iterate over tiles again and calculate overlaps.
for tile_idx_y in range(num_tiles_y):
for tile_idx_x in range(num_tiles_x):
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
assert cur_tile is not None
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
if top_neighbor_tile is not None:
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
if left_neighbor_tile is not None:
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
left_neighbor_tile.overlap.right = cur_tile.overlap.left
return tiles return tiles

View File

@ -10,11 +10,22 @@ class TBLR(BaseModel):
left: int left: int
right: int right: int
def __eq__(self, other):
return (
self.top == other.top
and self.bottom == other.bottom
and self.left == other.left
and self.right == other.right
)
class Tile(BaseModel): class Tile(BaseModel):
coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.") coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.") overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")
def __eq__(self, other):
return self.coords == other.coords and self.overlap == other.overlap
def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None): def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
"""Paste a source image into a destination image. """Paste a source image into a destination image.

View File

@ -0,0 +1,84 @@
import pytest
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap
from invokeai.backend.tiles.utils import TBLR, Tile
####################################
# Test calc_tiles_with_overlap(...)
####################################
def test_calc_tiles_with_overlap_single_tile():
"""Test calc_tiles_with_overlap() behavior when a single tile covers the image."""
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
expected_tiles = [
Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
]
assert tiles == expected_tiles
def test_calc_tiles_with_overlap_evenly_divisible():
"""Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles."""
# Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles.
tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64)
expected_tiles = [
# Row 0
Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
# Row 1
Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
]
assert tiles == expected_tiles
def test_calc_tiles_with_overlap_not_evenly_divisible():
"""Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
# Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps.
tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64)
expected_tiles = [
# Row 0
Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
# Row 1
Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
Tile(
coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
),
Tile(
coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
),
]
assert tiles == expected_tiles
@pytest.mark.parametrize(
["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"],
[
(128, 128, 128, 128, 127, False), # OK
(128, 128, 128, 128, 0, False), # OK
(128, 128, 64, 64, 0, False), # OK
(128, 128, 129, 128, 0, True), # tile_height exceeds image_height.
(128, 128, 128, 129, 0, True), # tile_width exceeds image_width.
(128, 128, 64, 128, 64, True), # overlap equals tile_height.
(128, 128, 128, 64, 64, True), # overlap equals tile_width.
],
)
def test_calc_tiles_with_overlap_input_validation(
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
):
"""Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
if raises:
with pytest.raises(AssertionError):
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
else:
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)