InvokeAI/tests/app/util/test_controlnet_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

51 lines
1.7 KiB
Python
Raw Permalink Normal View History

import numpy as np
import pytest
from PIL import Image
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.image_util.util import nms
@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_prepare_control_image_num_channels(num_channels):
    """Check that prepare_control_image(...) honours the requested `num_channels`.

    A blank 256x256 RGB PIL image is prepared for each channel count in the
    parametrization. The returned image must have shape
    (1, num_channels, 256, 256).
    """
    side = 256
    blank_rgb = Image.fromarray(np.zeros((side, side, 3), dtype=np.uint8))

    control_image = prepare_control_image(
        image=blank_rgb,
        width=side,
        height=side,
        num_channels=num_channels,
        device="cpu",
        do_classifier_free_guidance=False,
    )

    expected_shape = (1, num_channels, side, side)
    assert control_image.shape == expected_shape
@pytest.mark.parametrize("num_channels", [0, 4])
def test_prepare_control_image_num_channels_too_large(num_channels):
    """Check that prepare_control_image(...) rejects an out-of-range `num_channels`.

    Each parametrized value must cause a ValueError.

    NOTE: despite the function name, the parametrization covers a value that
    is too small (0) as well as one that is too large (4).
    """
    side = 256
    blank_rgb = Image.fromarray(np.zeros((side, side, 3), dtype=np.uint8))
    call_kwargs = dict(
        image=blank_rgb,
        width=side,
        height=side,
        num_channels=num_channels,
        device="cpu",
        do_classifier_free_guidance=False,
    )

    # Only the call under test sits inside the raises-context, so a failure
    # while building the input cannot satisfy the assertion by accident.
    with pytest.raises(ValueError):
        prepare_control_image(**call_kwargs)
@pytest.mark.parametrize("threshold,sigma", [(None, 1.0), (1, None)])
def test_nms_invalid_options(threshold: None | int, sigma: None | float):
    """Check that nms(...) raises a ValueError when only one of `threshold` and `sigma` is provided.

    Each parametrized case supplies one option and leaves the other as None.
    """
    blank = np.zeros((256, 256, 3), dtype=np.uint8)

    with pytest.raises(ValueError):
        nms(blank, threshold, sigma)