import numpy as np
import pytest

from invokeai.backend.tiles.utils import TBLR, paste


def test_paste_no_mask_success():
    """Paste a patterned tile into a blank canvas with no mask and verify the exact result."""
    canvas = np.zeros((5, 5, 3), dtype=np.uint8)

    # Give the tile a recognisable, asymmetric pattern so a misplaced or
    # transposed paste would be detected.
    tile = np.zeros((3, 3, 3), dtype=np.uint8)
    tile[0, :, 0] = 1  # Top row of channel 0 is all 1s.
    tile[:, 0, 1] = 2  # Left column of channel 1 is all 2s.

    # Target region: bottom-center of the canvas (rows 2..4, columns 1..3).
    region = TBLR(top=2, bottom=5, left=1, right=4)

    # The canvas we expect afterwards: the tile's pattern shifted to the region's origin.
    want = np.zeros_like(canvas)
    want[2, 1:4, 0] = 1
    want[2:5, 1, 1] = 2

    paste(dst_image=canvas, src_image=tile, box=region)

    # strict=True also requires matching shape and dtype.
    np.testing.assert_array_equal(canvas, want, strict=True)


def test_paste_with_mask_success():
    """Paste a patterned tile through a mask and verify only the masked pixel is blended in."""
    canvas = np.zeros((5, 5, 3), dtype=np.uint8)

    # Give the tile a recognisable pattern so the blended values can be checked exactly.
    tile = np.zeros((3, 3, 3), dtype=np.uint8)
    tile[0, :, 0] = 64  # Top row of channel 0 is all 64s.
    tile[:, 0, 1] = 128  # Left column of channel 1 is all 128s.

    # Target region: bottom-center of the canvas (rows 2..4, columns 1..3).
    region = TBLR(top=2, bottom=5, left=1, right=4)

    # Mask weight is 0.5 at the tile's top-left pixel and 0 everywhere else.
    blend = np.zeros((3, 3))
    blend[0, 0] = 0.5

    # Only canvas pixel (2, 1) should change: half of 64 in channel 0 and
    # half of 128 in channel 1, blended over the zero background.
    want = np.zeros_like(canvas)
    want[2, 1, :2] = (32, 64)

    paste(dst_image=canvas, src_image=tile, box=region, mask=blend)

    # strict=True also requires matching shape and dtype.
    np.testing.assert_array_equal(canvas, want, strict=True)


@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_box_overflows_dst_image(use_mask: bool):
    """Expect a ValueError when the paste region extends past the edge of the destination image."""
    canvas = np.zeros((5, 5, 3), dtype=np.uint8)
    tile = np.zeros((3, 3, 3), dtype=np.uint8)
    blend = np.zeros((3, 3)) if use_mask else None

    # The region has the tile's size but starts at row 3, so it would end at
    # row 6 of a 5-row canvas.
    tile_height, tile_width = tile.shape[:2]
    row_start, col_start = 3, 0
    region = TBLR(
        top=row_start,
        bottom=row_start + tile_height,
        left=col_start,
        right=col_start + tile_width,
    )

    with pytest.raises(ValueError):
        paste(dst_image=canvas, src_image=tile, box=region, mask=blend)


@pytest.mark.parametrize("use_mask", [True, False])
def test_paste_src_image_does_not_match_box(use_mask: bool):
    """Expect a ValueError when the paste region's size differs from the source image's size."""
    canvas = np.zeros((5, 5, 3), dtype=np.uint8)
    tile = np.zeros((3, 3, 3), dtype=np.uint8)
    blend = np.zeros((3, 3)) if use_mask else None

    # The region is one row shorter than the tile (same width), so the shapes disagree.
    tile_height, tile_width = tile.shape[:2]
    region = TBLR(top=0, bottom=tile_height - 1, left=0, right=tile_width)

    with pytest.raises(ValueError):
        paste(dst_image=canvas, src_image=tile, box=region, mask=blend)


def test_paste_mask_does_not_match_src_image():
    """Expect a ValueError when the mask's shape differs from the source image's height and width."""
    canvas = np.zeros((5, 5, 3), dtype=np.uint8)
    tile = np.zeros((3, 3, 3), dtype=np.uint8)
    tile_height, tile_width = tile.shape[:2]

    # The mask is one row shorter than the tile.
    blend = np.zeros((tile_height - 1, tile_width))

    # The region itself is valid: it has exactly the tile's height and width.
    region = TBLR(top=0, bottom=tile_height, left=0, right=tile_width)

    with pytest.raises(ValueError):
        paste(dst_image=canvas, src_image=tile, box=region, mask=blend)