mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
51 lines
1.7 KiB
Python
51 lines
1.7 KiB
Python
import numpy as np
|
|
import pytest
|
|
from PIL import Image
|
|
|
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
|
from invokeai.backend.image_util.util import nms
|
|
|
|
|
|
@pytest.mark.parametrize("num_channels", [1, 2, 3])
|
|
def test_prepare_control_image_num_channels(num_channels):
|
|
"""Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
|
|
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
|
|
pil_image = Image.fromarray(np_image)
|
|
|
|
torch_image = prepare_control_image(
|
|
image=pil_image,
|
|
width=256,
|
|
height=256,
|
|
num_channels=num_channels,
|
|
device="cpu",
|
|
do_classifier_free_guidance=False,
|
|
)
|
|
|
|
assert torch_image.shape == (1, num_channels, 256, 256)
|
|
|
|
|
|
@pytest.mark.parametrize("num_channels", [0, 4])
|
|
def test_prepare_control_image_num_channels_too_large(num_channels):
|
|
"""Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
|
|
supported range.
|
|
"""
|
|
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
|
|
pil_image = Image.fromarray(np_image)
|
|
|
|
with pytest.raises(ValueError):
|
|
_ = prepare_control_image(
|
|
image=pil_image,
|
|
width=256,
|
|
height=256,
|
|
num_channels=num_channels,
|
|
device="cpu",
|
|
do_classifier_free_guidance=False,
|
|
)
|
|
|
|
|
|
@pytest.mark.parametrize("threshold,sigma", [(None, 1.0), (1, None)])
|
|
def test_nms_invalid_options(threshold: None | int, sigma: None | float):
|
|
"""Test that an exception is raised in nms(...) if only one of the `threshold` or `sigma` parameters are provided."""
|
|
with pytest.raises(ValueError):
|
|
nms(np.zeros((256, 256, 3), dtype=np.uint8), threshold, sigma)
|