feat: add image utils

These all support controlnet processors.

- `pil_to_cv2`
- `cv2_to_pil`
- `pil_to_np`
- `np_to_pil`
- `normalize_image_channel_count` (a readable version of `HWC3` from the controlnet repo)
- `fit_image_to_resolution` (a readable version of `resize_image` from the controlnet repo)
- `non_maximum_suppression` (a readable version of `nms` from the controlnet repo)
- `safe_step` (a readable version of `safe_step` from the controlnet repo)
This commit is contained in:
psychedelicious 2024-03-21 18:57:33 +11:00 committed by Kent Keirsey
parent 01d8ab04a5
commit ca496f0380

View File

@ -1,5 +1,7 @@
from math import ceil, floor, sqrt from math import ceil, floor, sqrt
import cv2
import numpy as np
from PIL import Image from PIL import Image
@ -69,3 +71,134 @@ def make_grid(image_list, rows=None, cols=None):
i = i + 1 i = i + 1
return grid_img return grid_img
def pil_to_np(image: Image.Image) -> np.ndarray:
    """Returns the pixel data of a PIL image as a uint8 numpy array."""
    pixels = np.array(image, dtype=np.uint8)
    return pixels
def np_to_pil(image: np.ndarray) -> Image.Image:
    """Builds a PIL image from the pixel data in a numpy array."""
    pil_image = Image.fromarray(image)
    return pil_image
def pil_to_cv2(image: Image.Image) -> np.ndarray:
    """Converts a PIL image to a CV2 image.

    The image is read into a uint8 array and its channel order is swapped from RGB to BGR.
    """
    rgb_pixels = np.array(image, dtype=np.uint8)
    bgr_pixels = cv2.cvtColor(rgb_pixels, cv2.COLOR_RGB2BGR)
    return bgr_pixels
def cv2_to_pil(image: np.ndarray) -> Image.Image:
    """Converts a CV2 image to a PIL image.

    The channel order is swapped from BGR to RGB before the PIL image is built.
    """
    rgb_pixels = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return Image.fromarray(rgb_pixels)
def normalize_image_channel_count(image: np.ndarray) -> np.ndarray:
    """Normalizes an image to have 3 channels.

    - A 2D (height, width) image or a 1-channel image has its single channel repeated 3 times.
    - A 3-channel image is returned as-is (the same array object, not a copy).
    - A 4-channel image has its first 3 channels blended with a white background, using the 4th
      channel as the alpha channel.

    This function is adapted from https://github.com/lllyasviel/ControlNet.

    Args:
        image: The input image. Must be a uint8 array shaped (height, width) or (height, width, channels).

    Returns:
        The normalized uint8 image, shaped (height, width, 3).

    Raises:
        ValueError: If the image is not uint8, is not 2- or 3-dimensional, or does not have 1, 3 or 4 channels.
    """
    # Validate with explicit exceptions rather than `assert`, which is stripped when Python runs with -O.
    if image.dtype != np.uint8:
        raise ValueError(f"Expected a uint8 image, got dtype {image.dtype}.")

    if image.ndim == 2:
        # Treat a plain 2D array as a single-channel image.
        image = image[:, :, None]
    if image.ndim != 3:
        raise ValueError(f"Expected a 2D or 3D image, got {image.ndim} dimensions.")

    channels = image.shape[2]

    if channels == 3:
        return image

    if channels == 1:
        return np.concatenate([image, image, image], axis=2)

    if channels == 4:
        # Alpha-composite over white: color * alpha + white * (1 - alpha), computed in float32.
        color = image[:, :, 0:3].astype(np.float32)
        alpha = image[:, :, 3:4].astype(np.float32) / 255.0
        normalized = color * alpha + 255.0 * (1.0 - alpha)
        return normalized.clip(0, 255).astype(np.uint8)

    raise ValueError("Invalid number of channels.")
def fit_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.ndarray:
    """Resizes an image so that its shorter side approximately matches the given resolution.

    The image is scaled uniformly so its shorter side equals `resolution`, then each side is rounded to
    the nearest multiple of 64. Because of this rounding, the output may differ slightly from the
    requested resolution and from the original aspect ratio.

    This function is adapted from https://github.com/lllyasviel/ControlNet.

    Args:
        input_image: The input image. The first two axes are treated as height and width.
        resolution: The target size for the shorter side of the image.

    Returns:
        The resized image.
    """
    src_height = float(input_image.shape[0])
    src_width = float(input_image.shape[1])
    scale = float(resolution) / min(src_height, src_width)

    def _snap_to_64(length: float) -> int:
        # Scale the side, then round it to the nearest multiple of 64.
        return int(np.round(length * scale / 64.0)) * 64

    target_height = _snap_to_64(src_height)
    target_width = _snap_to_64(src_width)

    # Lanczos interpolation when enlarging, area interpolation otherwise.
    interpolation = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA

    # cv2.resize takes the target size as (width, height).
    return cv2.resize(input_image, (target_width, target_height), interpolation=interpolation)
def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
    """Applies non-maximum suppression to an image, producing a binary uint8 mask.

    The image is blurred, then a pixel is kept only if it equals the maximum of its 3-pixel
    neighbourhood along at least one of four directions (horizontal, vertical, or either diagonal).
    Kept pixels whose blurred value exceeds the threshold become 255; all other pixels become 0.

    This function is adapted from https://github.com/lllyasviel/ControlNet.

    Args:
        image: The input image.
        threshold: Kept pixels with a blurred value greater than this are set to 255.
        sigma: The standard deviation for the Gaussian blur applied to the image.

    Returns:
        A uint8 array with the same shape as the blurred image, containing only 0 and 255.
    """
    # A (0, 0) kernel size lets OpenCV derive the blur kernel size from sigma.
    blurred = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma)

    # One 3x3 line-shaped kernel per direction: horizontal, vertical, and the two diagonals.
    directional_kernels = (
        np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8),
        np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8),
        np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8),
        np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8),
    )

    peaks = np.zeros_like(blurred)
    for kernel in directional_kernels:
        # Dilation leaves a pixel unchanged exactly when it is the maximum along this kernel's line.
        is_directional_max = cv2.dilate(blurred, kernel=kernel) == blurred
        peaks[is_directional_max] = blurred[is_directional_max]

    return np.where(peaks > threshold, 255, 0).astype(np.uint8)
def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:
    """Quantizes an array onto evenly spaced levels that are multiples of `1 / step`.

    The array is multiplied by `step + 1`, truncated toward zero by a cast to int32, and then divided
    by `step`. For inputs in the range [0, 1), this yields one of the `step + 1` values
    0, 1/step, 2/step, ..., 1.

    This function is adapted from https://github.com/lllyasviel/ControlNet.

    Args:
        x: The input array.
        step: The step value. The output is always a multiple of `1 / step`.

    Returns:
        The quantized array, as float32.
    """
    level_count = float(step + 1)
    scaled = x.astype(np.float32) * level_count
    # The int32 cast drops the fractional part, snapping each value to an integer level.
    level_index = scaled.astype(np.int32)
    return level_index.astype(np.float32) / float(step)