InvokeAI/tests/app/util/test_controlnet_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

43 lines
1.3 KiB
Python
Raw Normal View History

import numpy as np
import pytest
from PIL import Image
from invokeai.app.util.controlnet_utils import prepare_control_image
@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_prepare_control_image_num_channels(num_channels):
    """Check that prepare_control_image(...) honours the requested `num_channels`.

    A 256x256 input image is passed through prepare_control_image(...) and the
    result must have shape (1, num_channels, 256, 256).
    """
    # All-zero 256x256 array with 3 channels, wrapped as a PIL image.
    blank_pixels = np.zeros((256, 256, 3), dtype=np.uint8)
    blank_image = Image.fromarray(blank_pixels)

    control = prepare_control_image(
        image=blank_image,
        width=256,
        height=256,
        num_channels=num_channels,
        device="cpu",
        do_classifier_free_guidance=False,
    )

    expected_shape = (1, num_channels, 256, 256)
    assert control.shape == expected_shape
@pytest.mark.parametrize("num_channels", [0, 4])
def test_prepare_control_image_num_channels_too_large(num_channels):
    """Check that prepare_control_image(...) rejects an out-of-range `num_channels`.

    Both a value below (0) and a value above (4) the supported range are
    exercised; each must cause a ValueError to be raised.
    """
    # All-zero 256x256 array with 3 channels, wrapped as a PIL image.
    blank_pixels = np.zeros((256, 256, 3), dtype=np.uint8)
    blank_image = Image.fromarray(blank_pixels)

    with pytest.raises(ValueError):
        # The return value is irrelevant; only the raised exception matters.
        prepare_control_image(
            image=blank_image,
            width=256,
            height=256,
            num_channels=num_channels,
            device="cpu",
            do_classifier_free_guidance=False,
        )