Add unit tests for tile paste(...) util function.

This commit is contained in:
Ryan Dick 2023-11-20 11:53:40 -05:00
parent d742479810
commit caf47dee09

View File

@ -0,0 +1,101 @@
import numpy as np
import pytest
from invokeai.backend.tiles.utils import TBLR, paste
def test_paste_no_mask_success():
"""Test successful paste with mask=None."""
dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
# Create src_image with a pattern that can be used to validate that it was pasted correctly.
src_image = np.zeros((3, 3, 3), dtype=np.uint8)
src_image[0, :, 0] = 1 # Row of 1s in channel 0.
src_image[:, 0, 1] = 2 # Column of 2s in channel 1.
# Paste in bottom-center of dst_image.
box = TBLR(top=2, bottom=5, left=1, right=4)
# Construct expected output image.
expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
expected_output[2, 1:4, 0] = 1
expected_output[2:5, 1, 1] = 2
paste(dst_image=dst_image, src_image=src_image, box=box)
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
def test_paste_with_mask_success():
"""Test successful paste with a mask."""
dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
# Create src_image with a pattern that can be used to validate that it was pasted correctly.
src_image = np.zeros((3, 3, 3), dtype=np.uint8)
src_image[0, :, 0] = 64 # Row of 64s in channel 0.
src_image[:, 0, 1] = 128 # Column of 128s in channel 1.
# Paste in bottom-center of dst_image.
box = TBLR(top=2, bottom=5, left=1, right=4)
# Create a mask that blends the top-left corner of 'src_image' at 50%, and ignores the rest of src_image.
mask = np.zeros((3, 3))
mask[0, 0] = 0.5
# Construct expected output image.
expected_output = np.zeros((5, 5, 3), dtype=np.uint8)
expected_output[2, 1, 0] = 32
expected_output[2, 1, 1] = 64
paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
np.testing.assert_array_equal(dst_image, expected_output, strict=True)
@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_box_overflows_dst_image(use_mask: bool):
"""Test that an exception is raised if 'box' overflows the 'dst_image'."""
dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
src_image = np.zeros((3, 3, 3), dtype=np.uint8)
mask = None
if use_mask:
mask = np.zeros((3, 3))
# Construct box that overflows bottom of dst_image.
top = 3
left = 0
box = TBLR(top=top, bottom=top + src_image.shape[0], left=left, right=left + src_image.shape[1])
with pytest.raises(ValueError):
paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_src_image_does_not_match_box(use_mask: bool):
"""Test that an exception is raised if the 'src_image' shape does not match the 'box' dimensions."""
dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
src_image = np.zeros((3, 3, 3), dtype=np.uint8)
mask = None
if use_mask:
mask = np.zeros((3, 3))
# Construct box that is smaller than src_image.
box = TBLR(top=0, bottom=src_image.shape[0] - 1, left=0, right=src_image.shape[1])
with pytest.raises(ValueError):
paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)
def test_paste_mask_does_not_match_src_image():
"""Test that an exception is raised if the 'mask' shape is different than the 'src_image' shape."""
dst_image = np.zeros((5, 5, 3), dtype=np.uint8)
src_image = np.zeros((3, 3, 3), dtype=np.uint8)
# Construct mask that is smaller than the src_image.
mask = np.zeros((src_image.shape[0] - 1, src_image.shape[1]))
# Construct box that matches src_image shape.
box = TBLR(top=0, bottom=src_image.shape[0], left=0, right=src_image.shape[1])
with pytest.raises(ValueError):
paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)