diff --git a/invokeai/backend/image_util/hed.py b/invokeai/backend/image_util/hed.py index 378e3b96e9..97706df8b9 100644 --- a/invokeai/backend/image_util/hed.py +++ b/invokeai/backend/image_util/hed.py @@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download from PIL import Image from invokeai.backend.image_util.util import ( - non_maximum_suppression, + nms, normalize_image_channel_count, np_to_pil, pil_to_np, @@ -134,7 +134,7 @@ class HEDProcessor: detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR) if scribble: - detected_map = non_maximum_suppression(detected_map, 127, 3.0) + detected_map = nms(detected_map, 127, 3.0) detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) detected_map[detected_map > 4] = 255 detected_map[detected_map < 255] = 0 diff --git a/invokeai/backend/image_util/util.py b/invokeai/backend/image_util/util.py index 7cfe0ad1a5..f704f068e3 100644 --- a/invokeai/backend/image_util/util.py +++ b/invokeai/backend/image_util/util.py @@ -1,4 +1,5 @@ from math import ceil, floor, sqrt +from typing import Optional import cv2 import numpy as np @@ -153,10 +154,13 @@ def resize_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.n return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA) -def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float): +def nms(np_img: np.ndarray, threshold: Optional[int] = None, sigma: Optional[float] = None) -> np.ndarray: """ Apply non-maximum suppression to an image. + If both threshold and sigma are provided, the image will be blurred before the suppression and thresholded afterwards, + resulting in a binary output image. + This function is adapted from https://github.com/lllyasviel/ControlNet. Args: @@ -166,23 +170,36 @@ def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float): Returns: The image after non-maximum suppression. + + Raises: + ValueError: If only one of threshold and sigma is provided. 
""" - image = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma) + # Raise a value error if only one of threshold and sigma is provided + if (threshold is None) != (sigma is None): + raise ValueError("Both threshold and sigma must be provided if one is provided.") + + if sigma is not None and threshold is not None: + # Blurring the image can help to thin out features + np_img = cv2.GaussianBlur(np_img.astype(np.float32), (0, 0), sigma) filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) - y = np.zeros_like(image) + nms_img = np.zeros_like(np_img) for f in [filter_1, filter_2, filter_3, filter_4]: - np.putmask(y, cv2.dilate(image, kernel=f) == image, image) + np.putmask(nms_img, cv2.dilate(np_img, kernel=f) == np_img, np_img) - z = np.zeros_like(y, dtype=np.uint8) - z[y > threshold] = 255 - return z + if sigma is not None and threshold is not None: + # We blurred - now threshold to get a binary image + thresholded = np.zeros_like(nms_img, dtype=np.uint8) + thresholded[nms_img > threshold] = 255 + return thresholded + + return nms_img def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray: diff --git a/tests/app/util/test_controlnet_utils.py b/tests/app/util/test_controlnet_utils.py index 21662cce8d..9806fe7806 100644 --- a/tests/app/util/test_controlnet_utils.py +++ b/tests/app/util/test_controlnet_utils.py @@ -3,6 +3,7 @@ import pytest from PIL import Image from invokeai.app.util.controlnet_utils import prepare_control_image +from invokeai.backend.image_util.util import nms @pytest.mark.parametrize("num_channels", [1, 2, 3]) @@ -40,3 +41,10 @@ def test_prepare_control_image_num_channels_too_large(num_channels): device="cpu", do_classifier_free_guidance=False, ) + + +@pytest.mark.parametrize("threshold,sigma", [(None, 1.0), (1, 
None)]) +def test_nms_invalid_options(threshold: None | int, sigma: None | float): + """Test that an exception is raised in nms(...) if only one of the `threshold` or `sigma` parameters is provided.""" + with pytest.raises(ValueError): + nms(np.zeros((256, 256, 3), dtype=np.uint8), threshold, sigma)