feat(backend): update nms util to make blur/thresholding optional

This commit is contained in:
psychedelicious 2024-04-25 11:28:39 +10:00
parent 5b8f77f990
commit 6b0bf59682
3 changed files with 34 additions and 9 deletions

View File

@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
from PIL import Image from PIL import Image
from invokeai.backend.image_util.util import ( from invokeai.backend.image_util.util import (
non_maximum_suppression, nms,
normalize_image_channel_count, normalize_image_channel_count,
np_to_pil, np_to_pil,
pil_to_np, pil_to_np,
@ -134,7 +134,7 @@ class HEDProcessor:
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR) detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
if scribble: if scribble:
detected_map = non_maximum_suppression(detected_map, 127, 3.0) detected_map = nms(detected_map, 127, 3.0)
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
detected_map[detected_map > 4] = 255 detected_map[detected_map > 4] = 255
detected_map[detected_map < 255] = 0 detected_map[detected_map < 255] = 0

View File

@ -1,4 +1,5 @@
from math import ceil, floor, sqrt from math import ceil, floor, sqrt
from typing import Optional
import cv2 import cv2
import numpy as np import numpy as np
@ -153,10 +154,13 @@ def resize_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.n
return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA) return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA)
def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float): def nms(np_img: np.ndarray, threshold: Optional[int] = None, sigma: Optional[float] = None) -> np.ndarray:
""" """
Apply non-maximum suppression to an image. Apply non-maximum suppression to an image.
If both threshold and sigma are provided, the image will blurred before the suppression and thresholded afterwards,
resulting in a binary output image.
This function is adapted from https://github.com/lllyasviel/ControlNet. This function is adapted from https://github.com/lllyasviel/ControlNet.
Args: Args:
@ -166,23 +170,36 @@ def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
Returns: Returns:
The image after non-maximum suppression. The image after non-maximum suppression.
Raises:
ValueError: If only one of threshold and sigma provided.
""" """
image = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma) # Raise a value error if only one of threshold and sigma is provided
if (threshold is None) != (sigma is None):
raise ValueError("Both threshold and sigma must be provided if one is provided.")
if sigma is not None and threshold is not None:
# Blurring the image can help to thin out features
np_img = cv2.GaussianBlur(np_img.astype(np.float32), (0, 0), sigma)
filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(image) nms_img = np.zeros_like(np_img)
for f in [filter_1, filter_2, filter_3, filter_4]: for f in [filter_1, filter_2, filter_3, filter_4]:
np.putmask(y, cv2.dilate(image, kernel=f) == image, image) np.putmask(nms_img, cv2.dilate(np_img, kernel=f) == np_img, np_img)
z = np.zeros_like(y, dtype=np.uint8) if sigma is not None and threshold is not None:
z[y > threshold] = 255 # We blurred - now threshold to get a binary image
return z thresholded = np.zeros_like(nms_img, dtype=np.uint8)
thresholded[nms_img > threshold] = 255
return thresholded
return nms_img
def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray: def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:

View File

@ -3,6 +3,7 @@ import pytest
from PIL import Image from PIL import Image
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.backend.image_util.util import nms
@pytest.mark.parametrize("num_channels", [1, 2, 3]) @pytest.mark.parametrize("num_channels", [1, 2, 3])
@ -40,3 +41,10 @@ def test_prepare_control_image_num_channels_too_large(num_channels):
device="cpu", device="cpu",
do_classifier_free_guidance=False, do_classifier_free_guidance=False,
) )
@pytest.mark.parametrize("threshold,sigma", [(None, 1.0), (1, None)])
def test_nms_invalid_options(threshold: None | int, sigma: None | float):
"""Test that an exception is raised in nms(...) if only one of the `threshold` or `sigma` parameters are provided."""
with pytest.raises(ValueError):
nms(np.zeros((256, 256, 3), dtype=np.uint8), threshold, sigma)