diff --git a/tests/backend/tiles/test_utils.py b/tests/backend/tiles/test_utils.py new file mode 100644 index 0000000000..bbef233ca5 --- /dev/null +++ b/tests/backend/tiles/test_utils.py @@ -0,0 +1,101 @@ +import numpy as np +import pytest + +from invokeai.backend.tiles.utils import TBLR, paste + + +def test_paste_no_mask_success(): + """Test successful paste with mask=None.""" + dst_image = np.zeros((5, 5, 3), dtype=np.uint8) + + # Create src_image with a pattern that can be used to validate that it was pasted correctly. + src_image = np.zeros((3, 3, 3), dtype=np.uint8) + src_image[0, :, 0] = 1 # Row of 1s in channel 0. + src_image[:, 0, 1] = 2 # Column of 2s in channel 1. + + # Paste in bottom-center of dst_image. + box = TBLR(top=2, bottom=5, left=1, right=4) + + # Construct expected output image. + expected_output = np.zeros((5, 5, 3), dtype=np.uint8) + expected_output[2, 1:4, 0] = 1 + expected_output[2:5, 1, 1] = 2 + + paste(dst_image=dst_image, src_image=src_image, box=box) + + np.testing.assert_array_equal(dst_image, expected_output, strict=True) + + +def test_paste_with_mask_success(): + """Test successful paste with a mask.""" + dst_image = np.zeros((5, 5, 3), dtype=np.uint8) + + # Create src_image with a pattern that can be used to validate that it was pasted correctly. + src_image = np.zeros((3, 3, 3), dtype=np.uint8) + src_image[0, :, 0] = 64 # Row of 64s in channel 0. + src_image[:, 0, 1] = 128 # Column of 128s in channel 1. + + # Paste in bottom-center of dst_image. + box = TBLR(top=2, bottom=5, left=1, right=4) + + # Create a mask that blends the top-left corner of 'src_image' at 50%, and ignores the rest of src_image. + mask = np.zeros((3, 3)) + mask[0, 0] = 0.5 + + # Construct expected output image. + expected_output = np.zeros((5, 5, 3), dtype=np.uint8) + expected_output[2, 1, 0] = 32 + expected_output[2, 1, 1] = 64 + + paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask) + + np.testing.assert_array_equal(dst_image, expected_output, strict=True) + + +@pytest.mark.parametrize("use_mask", [True, False]) +def test_paste_box_overflows_dst_image(use_mask: bool): + """Test that an exception is raised if 'box' overflows the 'dst_image'.""" + dst_image = np.zeros((5, 5, 3), dtype=np.uint8) + src_image = np.zeros((3, 3, 3), dtype=np.uint8) + mask = None + if use_mask: + mask = np.zeros((3, 3)) + + # Construct box that overflows bottom of dst_image. + top = 3 + left = 0 + box = TBLR(top=top, bottom=top + src_image.shape[0], left=left, right=left + src_image.shape[1]) + + with pytest.raises(ValueError): + paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask) + + +@pytest.mark.parametrize("use_mask", [True, False]) +def test_paste_src_image_does_not_match_box(use_mask: bool): + """Test that an exception is raised if the 'src_image' shape does not match the 'box' dimensions.""" + dst_image = np.zeros((5, 5, 3), dtype=np.uint8) + src_image = np.zeros((3, 3, 3), dtype=np.uint8) + mask = None + if use_mask: + mask = np.zeros((3, 3)) + + # Construct box that is smaller than src_image. + box = TBLR(top=0, bottom=src_image.shape[0] - 1, left=0, right=src_image.shape[1]) + + with pytest.raises(ValueError): + paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask) + + +def test_paste_mask_does_not_match_src_image(): + """Test that an exception is raised if the 'mask' shape is different than the 'src_image' shape.""" + dst_image = np.zeros((5, 5, 3), dtype=np.uint8) + src_image = np.zeros((3, 3, 3), dtype=np.uint8) + + # Construct mask that is smaller than the src_image. + mask = np.zeros((src_image.shape[0] - 1, src_image.shape[1])) + + # Construct box that matches src_image shape. + box = TBLR(top=0, bottom=src_image.shape[0], left=0, right=src_image.shape[1]) + + with pytest.raises(ValueError): + paste(dst_image=dst_image, src_image=src_image, box=box, mask=mask)