mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(backend): update nms util to make blur/thresholding optional
This commit is contained in:
parent
5b8f77f990
commit
6b0bf59682
@ -8,7 +8,7 @@ from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.backend.image_util.util import (
|
||||
non_maximum_suppression,
|
||||
nms,
|
||||
normalize_image_channel_count,
|
||||
np_to_pil,
|
||||
pil_to_np,
|
||||
@ -134,7 +134,7 @@ class HEDProcessor:
|
||||
detected_map = cv2.resize(detected_map, (width, height), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
if scribble:
|
||||
detected_map = non_maximum_suppression(detected_map, 127, 3.0)
|
||||
detected_map = nms(detected_map, 127, 3.0)
|
||||
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
||||
detected_map[detected_map > 4] = 255
|
||||
detected_map[detected_map < 255] = 0
|
||||
|
@ -1,4 +1,5 @@
|
||||
from math import ceil, floor, sqrt
|
||||
from typing import Optional
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -153,10 +154,13 @@ def resize_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.n
|
||||
return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA)
|
||||
|
||||
|
||||
def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
|
||||
def nms(np_img: np.ndarray, threshold: Optional[int] = None, sigma: Optional[float] = None) -> np.ndarray:
|
||||
"""
|
||||
Apply non-maximum suppression to an image.
|
||||
|
||||
If both threshold and sigma are provided, the image will blurred before the suppression and thresholded afterwards,
|
||||
resulting in a binary output image.
|
||||
|
||||
This function is adapted from https://github.com/lllyasviel/ControlNet.
|
||||
|
||||
Args:
|
||||
@ -166,23 +170,36 @@ def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
|
||||
|
||||
Returns:
|
||||
The image after non-maximum suppression.
|
||||
|
||||
Raises:
|
||||
ValueError: If only one of threshold and sigma provided.
|
||||
"""
|
||||
|
||||
image = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma)
|
||||
# Raise a value error if only one of threshold and sigma is provided
|
||||
if (threshold is None) != (sigma is None):
|
||||
raise ValueError("Both threshold and sigma must be provided if one is provided.")
|
||||
|
||||
if sigma is not None and threshold is not None:
|
||||
# Blurring the image can help to thin out features
|
||||
np_img = cv2.GaussianBlur(np_img.astype(np.float32), (0, 0), sigma)
|
||||
|
||||
filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
||||
filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
||||
filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
||||
filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
||||
|
||||
y = np.zeros_like(image)
|
||||
nms_img = np.zeros_like(np_img)
|
||||
|
||||
for f in [filter_1, filter_2, filter_3, filter_4]:
|
||||
np.putmask(y, cv2.dilate(image, kernel=f) == image, image)
|
||||
np.putmask(nms_img, cv2.dilate(np_img, kernel=f) == np_img, np_img)
|
||||
|
||||
z = np.zeros_like(y, dtype=np.uint8)
|
||||
z[y > threshold] = 255
|
||||
return z
|
||||
if sigma is not None and threshold is not None:
|
||||
# We blurred - now threshold to get a binary image
|
||||
thresholded = np.zeros_like(nms_img, dtype=np.uint8)
|
||||
thresholded[nms_img > threshold] = 255
|
||||
return thresholded
|
||||
|
||||
return nms_img
|
||||
|
||||
|
||||
def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:
|
||||
|
@ -3,6 +3,7 @@ import pytest
|
||||
from PIL import Image
|
||||
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.backend.image_util.util import nms
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_channels", [1, 2, 3])
|
||||
@ -40,3 +41,10 @@ def test_prepare_control_image_num_channels_too_large(num_channels):
|
||||
device="cpu",
|
||||
do_classifier_free_guidance=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("threshold,sigma", [(None, 1.0), (1, None)])
|
||||
def test_nms_invalid_options(threshold: None | int, sigma: None | float):
|
||||
"""Test that an exception is raised in nms(...) if only one of the `threshold` or `sigma` parameters are provided."""
|
||||
with pytest.raises(ValueError):
|
||||
nms(np.zeros((256, 256, 3), dtype=np.uint8), threshold, sigma)
|
||||
|
Loading…
Reference in New Issue
Block a user