import numpy as np
import pytest
from PIL import Image

from invokeai.app.util.controlnet_utils import prepare_control_image


@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_prepare_control_image_num_channels(num_channels):
    """Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
    pil_image = Image.fromarray(np_image)

    torch_image = prepare_control_image(
        image=pil_image,
        width=256,
        height=256,
        num_channels=num_channels,
        device="cpu",
        do_classifier_free_guidance=False,
    )

    assert torch_image.shape == (1, num_channels, 256, 256)


@pytest.mark.parametrize("num_channels", [0, 4])
def test_prepare_control_image_num_channels_too_large(num_channels):
    """Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
    supported range.
    """
    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
    pil_image = Image.fromarray(np_image)

    with pytest.raises(ValueError):
        _ = prepare_control_image(
            image=pil_image,
            width=256,
            height=256,
            num_channels=num_channels,
            device="cpu",
            do_classifier_free_guidance=False,
        )