mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add unit tests for calc_tiles_with_overlap(...) and fix a bug in its implementation.
This commit is contained in:
parent
caf47dee09
commit
1f63fa8236
@ -1,17 +1,10 @@
|
|||||||
import math
|
import math
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from invokeai.backend.tiles.utils import TBLR, Tile, paste
|
from invokeai.backend.tiles.utils import TBLR, Tile, paste
|
||||||
|
|
||||||
# TODO(ryand)
|
|
||||||
# Test the following:
|
|
||||||
# - Tile too big in x, y
|
|
||||||
# - Overlap too big in x, y
|
|
||||||
# - Single tile fits
|
|
||||||
# - Multiple tiles fit perfectly
|
|
||||||
# - Not evenly divisible by tile size(with overlap)
|
|
||||||
|
|
||||||
|
|
||||||
def calc_tiles_with_overlap(
|
def calc_tiles_with_overlap(
|
||||||
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
|
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int = 0
|
||||||
@ -40,8 +33,10 @@ def calc_tiles_with_overlap(
|
|||||||
num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
|
num_tiles_y = math.ceil((image_height - overlap) / non_overlap_per_tile_height)
|
||||||
num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)
|
num_tiles_x = math.ceil((image_width - overlap) / non_overlap_per_tile_width)
|
||||||
|
|
||||||
# Calculate tile coordinates and overlaps.
|
# tiles[y * num_tiles_x + x] is the tile for the y'th row, x'th column.
|
||||||
tiles: list[Tile] = []
|
tiles: list[Tile] = []
|
||||||
|
|
||||||
|
# Calculate tile coordinates. (Ignore overlap values for now.)
|
||||||
for tile_idx_y in range(num_tiles_y):
|
for tile_idx_y in range(num_tiles_y):
|
||||||
for tile_idx_x in range(num_tiles_x):
|
for tile_idx_x in range(num_tiles_x):
|
||||||
tile = Tile(
|
tile = Tile(
|
||||||
@ -51,12 +46,7 @@ def calc_tiles_with_overlap(
|
|||||||
left=tile_idx_x * non_overlap_per_tile_width,
|
left=tile_idx_x * non_overlap_per_tile_width,
|
||||||
right=tile_idx_x * non_overlap_per_tile_width + tile_width,
|
right=tile_idx_x * non_overlap_per_tile_width + tile_width,
|
||||||
),
|
),
|
||||||
overlap=TBLR(
|
overlap=TBLR(top=0, bottom=0, left=0, right=0),
|
||||||
top=0 if tile_idx_y == 0 else overlap,
|
|
||||||
bottom=overlap,
|
|
||||||
left=0 if tile_idx_x == 0 else overlap,
|
|
||||||
right=overlap,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if tile.coords.bottom > image_height:
|
if tile.coords.bottom > image_height:
|
||||||
@ -64,23 +54,39 @@ def calc_tiles_with_overlap(
|
|||||||
# of the image.
|
# of the image.
|
||||||
tile.coords.bottom = image_height
|
tile.coords.bottom = image_height
|
||||||
tile.coords.top = image_height - tile_height
|
tile.coords.top = image_height - tile_height
|
||||||
tile.overlap.bottom = 0
|
|
||||||
# Note that this could result in a large overlap between this tile and the one above it.
|
|
||||||
top_neighbor_bottom = (tile_idx_y - 1) * non_overlap_per_tile_height + tile_height
|
|
||||||
tile.overlap.top = top_neighbor_bottom - tile.coords.top
|
|
||||||
|
|
||||||
if tile.coords.right > image_width:
|
if tile.coords.right > image_width:
|
||||||
# If this tile would go off the right edge of the image, shift it so that it is aligned with the
|
# If this tile would go off the right edge of the image, shift it so that it is aligned with the
|
||||||
# right edge of the image.
|
# right edge of the image.
|
||||||
tile.coords.right = image_width
|
tile.coords.right = image_width
|
||||||
tile.coords.left = image_width - tile_width
|
tile.coords.left = image_width - tile_width
|
||||||
tile.overlap.right = 0
|
|
||||||
# Note that this could result in a large overlap between this tile and the one to its left.
|
|
||||||
left_neighbor_right = (tile_idx_x - 1) * non_overlap_per_tile_width + tile_width
|
|
||||||
tile.overlap.left = left_neighbor_right - tile.coords.left
|
|
||||||
|
|
||||||
tiles.append(tile)
|
tiles.append(tile)
|
||||||
|
|
||||||
|
def get_tile_or_none(idx_y: int, idx_x: int) -> Union[Tile, None]:
|
||||||
|
if idx_y < 0 or idx_y > num_tiles_y or idx_x < 0 or idx_x > num_tiles_x:
|
||||||
|
return None
|
||||||
|
return tiles[idx_y * num_tiles_x + idx_x]
|
||||||
|
|
||||||
|
# Iterate over tiles again and calculate overlaps.
|
||||||
|
for tile_idx_y in range(num_tiles_y):
|
||||||
|
for tile_idx_x in range(num_tiles_x):
|
||||||
|
cur_tile = get_tile_or_none(tile_idx_y, tile_idx_x)
|
||||||
|
top_neighbor_tile = get_tile_or_none(tile_idx_y - 1, tile_idx_x)
|
||||||
|
left_neighbor_tile = get_tile_or_none(tile_idx_y, tile_idx_x - 1)
|
||||||
|
|
||||||
|
assert cur_tile is not None
|
||||||
|
|
||||||
|
# Update cur_tile top-overlap and corresponding top-neighbor bottom-overlap.
|
||||||
|
if top_neighbor_tile is not None:
|
||||||
|
cur_tile.overlap.top = max(0, top_neighbor_tile.coords.bottom - cur_tile.coords.top)
|
||||||
|
top_neighbor_tile.overlap.bottom = cur_tile.overlap.top
|
||||||
|
|
||||||
|
# Update cur_tile left-overlap and corresponding left-neighbor right-overlap.
|
||||||
|
if left_neighbor_tile is not None:
|
||||||
|
cur_tile.overlap.left = max(0, left_neighbor_tile.coords.right - cur_tile.coords.left)
|
||||||
|
left_neighbor_tile.overlap.right = cur_tile.overlap.left
|
||||||
|
|
||||||
return tiles
|
return tiles
|
||||||
|
|
||||||
|
|
||||||
|
@ -10,11 +10,22 @@ class TBLR(BaseModel):
|
|||||||
left: int
|
left: int
|
||||||
right: int
|
right: int
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return (
|
||||||
|
self.top == other.top
|
||||||
|
and self.bottom == other.bottom
|
||||||
|
and self.left == other.left
|
||||||
|
and self.right == other.right
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Tile(BaseModel):
|
class Tile(BaseModel):
|
||||||
coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
|
coords: TBLR = Field(description="The coordinates of this tile relative to its parent image.")
|
||||||
overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")
|
overlap: TBLR = Field(description="The amount of overlap with adjacent tiles on each side of this tile.")
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.coords == other.coords and self.overlap == other.overlap
|
||||||
|
|
||||||
|
|
||||||
def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
|
def paste(dst_image: np.ndarray, src_image: np.ndarray, box: TBLR, mask: Optional[np.ndarray] = None):
|
||||||
"""Paste a source image into a destination image.
|
"""Paste a source image into a destination image.
|
||||||
|
84
tests/backend/tiles/test_tiles.py
Normal file
84
tests/backend/tiles/test_tiles.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap
|
||||||
|
from invokeai.backend.tiles.utils import TBLR, Tile
|
||||||
|
|
||||||
|
####################################
|
||||||
|
# Test calc_tiles_with_overlap(...)
|
||||||
|
####################################
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_with_overlap_single_tile():
|
||||||
|
"""Test calc_tiles_with_overlap() behavior when a single tile covers the image."""
|
||||||
|
tiles = calc_tiles_with_overlap(image_height=512, image_width=1024, tile_height=512, tile_width=1024, overlap=64)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
Tile(coords=TBLR(top=0, bottom=512, left=0, right=1024), overlap=TBLR(top=0, bottom=0, left=0, right=0))
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_with_overlap_evenly_divisible():
|
||||||
|
"""Test calc_tiles_with_overlap() behavior when the image is evenly covered by multiple tiles."""
|
||||||
|
# Parameters chosen so that image is evenly covered by 2 rows, 3 columns of tiles.
|
||||||
|
tiles = calc_tiles_with_overlap(image_height=576, image_width=1600, tile_height=320, tile_width=576, overlap=64)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=0, right=576), overlap=TBLR(top=0, bottom=64, left=0, right=64)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=512, right=1088), overlap=TBLR(top=0, bottom=64, left=64, right=64)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=320, left=1024, right=1600), overlap=TBLR(top=0, bottom=64, left=64, right=0)),
|
||||||
|
# Row 1
|
||||||
|
Tile(coords=TBLR(top=256, bottom=576, left=0, right=576), overlap=TBLR(top=64, bottom=0, left=0, right=64)),
|
||||||
|
Tile(coords=TBLR(top=256, bottom=576, left=512, right=1088), overlap=TBLR(top=64, bottom=0, left=64, right=64)),
|
||||||
|
Tile(coords=TBLR(top=256, bottom=576, left=1024, right=1600), overlap=TBLR(top=64, bottom=0, left=64, right=0)),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
def test_calc_tiles_with_overlap_not_evenly_divisible():
|
||||||
|
"""Test calc_tiles_with_overlap() behavior when the image requires 'uneven' overlaps to achieve proper coverage."""
|
||||||
|
# Parameters chosen so that image is covered by 2 rows and 3 columns of tiles, with uneven overlaps.
|
||||||
|
tiles = calc_tiles_with_overlap(image_height=400, image_width=1200, tile_height=256, tile_width=512, overlap=64)
|
||||||
|
|
||||||
|
expected_tiles = [
|
||||||
|
# Row 0
|
||||||
|
Tile(coords=TBLR(top=0, bottom=256, left=0, right=512), overlap=TBLR(top=0, bottom=112, left=0, right=64)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=256, left=448, right=960), overlap=TBLR(top=0, bottom=112, left=64, right=272)),
|
||||||
|
Tile(coords=TBLR(top=0, bottom=256, left=688, right=1200), overlap=TBLR(top=0, bottom=112, left=272, right=0)),
|
||||||
|
# Row 1
|
||||||
|
Tile(coords=TBLR(top=144, bottom=400, left=0, right=512), overlap=TBLR(top=112, bottom=0, left=0, right=64)),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=448, right=960), overlap=TBLR(top=112, bottom=0, left=64, right=272)
|
||||||
|
),
|
||||||
|
Tile(
|
||||||
|
coords=TBLR(top=144, bottom=400, left=688, right=1200), overlap=TBLR(top=112, bottom=0, left=272, right=0)
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
assert tiles == expected_tiles
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
["image_height", "image_width", "tile_height", "tile_width", "overlap", "raises"],
|
||||||
|
[
|
||||||
|
(128, 128, 128, 128, 127, False), # OK
|
||||||
|
(128, 128, 128, 128, 0, False), # OK
|
||||||
|
(128, 128, 64, 64, 0, False), # OK
|
||||||
|
(128, 128, 129, 128, 0, True), # tile_height exceeds image_height.
|
||||||
|
(128, 128, 128, 129, 0, True), # tile_width exceeds image_width.
|
||||||
|
(128, 128, 64, 128, 64, True), # overlap equals tile_height.
|
||||||
|
(128, 128, 128, 64, 64, True), # overlap equals tile_width.
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_calc_tiles_with_overlap_input_validation(
|
||||||
|
image_height: int, image_width: int, tile_height: int, tile_width: int, overlap: int, raises: bool
|
||||||
|
):
|
||||||
|
"""Test that calc_tiles_with_overlap() raises an exception if the inputs are invalid."""
|
||||||
|
if raises:
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
|
||||||
|
else:
|
||||||
|
calc_tiles_with_overlap(image_height, image_width, tile_height, tile_width, overlap)
|
Loading…
Reference in New Issue
Block a user