tidy(backend): clean up controlnet_utils

- Use the our adaptation of the HWC3 function with better types
- Extraction some of the util functions, name them better, add comments
- Improve type annotations
- Remove unreachable codepaths
This commit is contained in:
psychedelicious 2024-04-25 13:05:11 +10:00
parent 6b0bf59682
commit 398f37c0ed

View File

@ -1,12 +1,13 @@
from typing import Union from typing import Any, Literal, Union
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from controlnet_aux.util import HWC3
from diffusers.utils import PIL_INTERPOLATION
from einops import rearrange from einops import rearrange
from PIL import Image from PIL import Image
from invokeai.backend.image_util.util import nms, normalize_image_channel_count
CONTROLNET_RESIZE_VALUES = Literal[ CONTROLNET_RESIZE_VALUES = Literal[
"just_resize", "just_resize",
"crop_resize", "crop_resize",
@ -75,17 +76,6 @@ def lvmin_thin(x, prunings=True):
return y return y
def nake_nms(x):
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
y = np.zeros_like(x)
for f in [f1, f2, f3, f4]:
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
return y
################################################################################ ################################################################################
# copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI # copied from Mikubill/sd-webui-controlnet external_code.py and modified for InvokeAI
################################################################################ ################################################################################
@ -141,98 +131,122 @@ def pixel_perfect_resolution(
return int(np.round(estimation)) return int(np.round(estimation))
def clone_contiguous(x: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
"""Get a memory-contiguous clone of the given numpy array, as a safety measure and to improve computation efficiency."""
return np.ascontiguousarray(x).copy()
def np_img_to_torch(np_img: np.ndarray[Any, Any], device: torch.device) -> torch.Tensor:
"""Convert a numpy image to a PyTorch tensor. The image is normalized to 0-1, rearranged to BCHW format and sent to
the specified device."""
torch_img = torch.from_numpy(np_img)
normalized = torch_img.float() / 255.0
bchw = rearrange(normalized, "h w c -> 1 c h w")
on_device = bchw.to(device)
return on_device.clone()
def heuristic_resize(np_img: np.ndarray[Any, Any], size: tuple[int, int]) -> np.ndarray[Any, Any]:
"""Resizes an image using a heuristic to choose the best resizing strategy.
- If the image appears to be an edge map, special handling will be applied to ensure the edges are not distorted.
- Single-pixel edge maps use NMS and thinning to keep the edges as single-pixel lines.
- Low-color-count images are resized with nearest-neighbor to preserve color information (for e.g. segmentation maps).
- The alpha channel is handled separately to ensure it is resized correctly.
Args:
np_img (np.ndarray): The input image.
size (tuple[int, int]): The target size for the image.
Returns:
np.ndarray: The resized image.
Adapted from https://github.com/Mikubill/sd-webui-controlnet.
"""
# Return early if the image is already at the requested size
if np_img.shape[0] == size[1] and np_img.shape[1] == size[0]:
return np_img
# If the image has an alpha channel, separate it for special handling later.
inpaint_mask = None
if np_img.ndim == 3 and np_img.shape[2] == 4:
inpaint_mask = np_img[:, :, 3]
np_img = np_img[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (np_img.shape[0] * np_img.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (np_img.shape[0] * np_img.shape[1])
unique_color_count = np.unique(np_img.reshape(-1, np_img.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
# If the image has only two colors, it is likely binary. Check if the image has one-pixel edges.
is_binary = np.min(np_img) < 16 and np.max(np_img) > 240
if is_binary:
eroded = cv2.erode(np_img, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
dilated = cv2.dilate(eroded, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(dilated < np_img)[0].shape[0]
all_edge_count = np.where(np_img > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
# With a low color count, we assume this is a map where exact colors are important. Near-neighbor preserves
# the colors as needed.
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
# This works best for downscaling
interpolation = cv2.INTER_AREA
else:
# Fall back for other cases
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
# This may be further transformed depending on the binary nature of the image.
resized = cv2.resize(np_img, size, interpolation=interpolation)
if inpaint_mask is not None:
# Resize the inpaint mask to match the resized image using the same interpolation method.
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
# If the image is binary, we will perform some additional processing to ensure the edges are preserved.
if is_binary:
resized = np.mean(resized.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
# Use NMS and thinning to keep the edges as single-pixel lines.
resized = nms(resized)
_, resized = cv2.threshold(resized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
resized = lvmin_thin(resized, prunings=new_size_is_bigger)
else:
_, resized = cv2.threshold(resized, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
resized = np.stack([resized] * 3, axis=2)
# Restore the alpha channel if it was present.
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
resized = np.concatenate([resized, inpaint_mask], axis=2)
return resized
########################################################################### ###########################################################################
# Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet # Copied from detectmap_proc method in scripts/detectmap_proc.py in Mikubill/sd-webui-controlnet
# modified for InvokeAI # modified for InvokeAI
########################################################################### ###########################################################################
# def detectmap_proc(detected_map, module, resize_mode, h, w): def np_img_resize(
def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device: torch.device = torch.device("cpu")): np_img: np.ndarray,
# if 'inpaint' in module: resize_mode: CONTROLNET_RESIZE_VALUES,
# np_img = np_img.astype(np.float32) h: int,
# else: w: int,
# np_img = HWC3(np_img) device: torch.device = torch.device("cpu"),
np_img = HWC3(np_img) ) -> tuple[torch.Tensor, np.ndarray[Any, Any]]:
np_img = normalize_image_channel_count(np_img)
def safe_numpy(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = y.copy()
y = np.ascontiguousarray(y)
y = y.copy()
return y
def get_pytorch_control(x):
# A very safe method to make sure that Apple/Mac works
y = x
# below is very boring but do not change these. If you change these Apple or Mac may fail.
y = torch.from_numpy(y)
y = y.float() / 255.0
y = rearrange(y, "h w c -> 1 c h w")
y = y.clone()
# y = y.to(devices.get_device_for("controlnet"))
y = y.to(device)
y = y.clone()
return y
def high_quality_resize(x: np.ndarray, size):
# Written by lvmin
# Super high-quality control map up-scaling, considering binary, seg, and one-pixel edges
inpaint_mask = None
if x.ndim == 3 and x.shape[2] == 4:
inpaint_mask = x[:, :, 3]
x = x[:, :, 0:3]
new_size_is_smaller = (size[0] * size[1]) < (x.shape[0] * x.shape[1])
new_size_is_bigger = (size[0] * size[1]) > (x.shape[0] * x.shape[1])
unique_color_count = np.unique(x.reshape(-1, x.shape[2]), axis=0).shape[0]
is_one_pixel_edge = False
is_binary = False
if unique_color_count == 2:
is_binary = np.min(x) < 16 and np.max(x) > 240
if is_binary:
xc = x
xc = cv2.erode(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
xc = cv2.dilate(xc, np.ones(shape=(3, 3), dtype=np.uint8), iterations=1)
one_pixel_edge_count = np.where(xc < x)[0].shape[0]
all_edge_count = np.where(x > 127)[0].shape[0]
is_one_pixel_edge = one_pixel_edge_count * 2 > all_edge_count
if 2 < unique_color_count < 200:
interpolation = cv2.INTER_NEAREST
elif new_size_is_smaller:
interpolation = cv2.INTER_AREA
else:
interpolation = cv2.INTER_CUBIC # Must be CUBIC because we now use nms. NEVER CHANGE THIS
y = cv2.resize(x, size, interpolation=interpolation)
if inpaint_mask is not None:
inpaint_mask = cv2.resize(inpaint_mask, size, interpolation=interpolation)
if is_binary:
y = np.mean(y.astype(np.float32), axis=2).clip(0, 255).astype(np.uint8)
if is_one_pixel_edge:
y = nake_nms(y)
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = lvmin_thin(y, prunings=new_size_is_bigger)
else:
_, y = cv2.threshold(y, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
y = np.stack([y] * 3, axis=2)
if inpaint_mask is not None:
inpaint_mask = (inpaint_mask > 127).astype(np.float32) * 255.0
inpaint_mask = inpaint_mask[:, :, None].clip(0, 255).astype(np.uint8)
y = np.concatenate([y, inpaint_mask], axis=2)
return y
# if resize_mode == external_code.ResizeMode.RESIZE:
if resize_mode == "just_resize": # RESIZE if resize_mode == "just_resize": # RESIZE
np_img = high_quality_resize(np_img, (w, h)) np_img = heuristic_resize(np_img, (w, h))
np_img = safe_numpy(np_img) np_img = clone_contiguous(np_img)
return get_pytorch_control(np_img), np_img return np_img_to_torch(np_img, device), np_img
old_h, old_w, _ = np_img.shape old_h, old_w, _ = np_img.shape
old_w = float(old_w) old_w = float(old_w)
@ -243,7 +257,6 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def safeint(x: Union[int, float]) -> int: def safeint(x: Union[int, float]) -> int:
return int(np.round(x)) return int(np.round(x))
# if resize_mode == external_code.ResizeMode.OUTER_FIT:
if resize_mode == "fill_resize": # OUTER_FIT if resize_mode == "fill_resize": # OUTER_FIT
k = min(k0, k1) k = min(k0, k1)
borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0) borders = np.concatenate([np_img[0, :, :], np_img[-1, :, :], np_img[:, 0, :], np_img[:, -1, :]], axis=0)
@ -252,23 +265,23 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
# Inpaint hijack # Inpaint hijack
high_quality_border_color[3] = 255 high_quality_border_color[3] = 255
high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1]) high_quality_background = np.tile(high_quality_border_color[None, None], [h, w, 1])
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape new_h, new_w, _ = np_img.shape
pad_h = max(0, (h - new_h) // 2) pad_h = max(0, (h - new_h) // 2)
pad_w = max(0, (w - new_w) // 2) pad_w = max(0, (w - new_w) // 2)
high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img high_quality_background[pad_h : pad_h + new_h, pad_w : pad_w + new_w] = np_img
np_img = high_quality_background np_img = high_quality_background
np_img = safe_numpy(np_img) np_img = clone_contiguous(np_img)
return get_pytorch_control(np_img), np_img return np_img_to_torch(np_img, device), np_img
else: # resize_mode == "crop_resize" (INNER_FIT) else: # resize_mode == "crop_resize" (INNER_FIT)
k = max(k0, k1) k = max(k0, k1)
np_img = high_quality_resize(np_img, (safeint(old_w * k), safeint(old_h * k))) np_img = heuristic_resize(np_img, (safeint(old_w * k), safeint(old_h * k)))
new_h, new_w, _ = np_img.shape new_h, new_w, _ = np_img.shape
pad_h = max(0, (new_h - h) // 2) pad_h = max(0, (new_h - h) // 2)
pad_w = max(0, (new_w - w) // 2) pad_w = max(0, (new_w - w) // 2)
np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w] np_img = np_img[pad_h : pad_h + h, pad_w : pad_w + w]
np_img = safe_numpy(np_img) np_img = clone_contiguous(np_img)
return get_pytorch_control(np_img), np_img return np_img_to_torch(np_img, device), np_img
def prepare_control_image( def prepare_control_image(
@ -276,12 +289,12 @@ def prepare_control_image(
width: int, width: int,
height: int, height: int,
num_channels: int = 3, num_channels: int = 3,
device="cuda", device: str = "cuda",
dtype=torch.float16, dtype: torch.dtype = torch.float16,
do_classifier_free_guidance=True, control_mode: CONTROLNET_MODE_VALUES = "balanced",
control_mode="balanced", resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
resize_mode="just_resize_simple", do_classifier_free_guidance: bool = True,
): ) -> torch.Tensor:
"""Pre-process images for ControlNets or T2I-Adapters. """Pre-process images for ControlNets or T2I-Adapters.
Args: Args:
@ -299,26 +312,15 @@ def prepare_control_image(
resize_mode (str, optional): Defaults to "just_resize_simple". resize_mode (str, optional): Defaults to "just_resize_simple".
Raises: Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized. ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range. ValueError: If `num_channels` is out of range.
Returns: Returns:
torch.Tensor: The pre-processed input tensor. torch.Tensor: The pre-processed input tensor.
""" """
if ( if resize_mode == "just_resize_simple":
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
or resize_mode == "fill_resize_simple"
):
image = image.convert("RGB") image = image.convert("RGB")
if resize_mode == "just_resize_simple": image = image.resize((width, height), resample=Image.LANCZOS)
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "fill_resize_simple":
raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
nimage = np.array(image) nimage = np.array(image)
nimage = nimage[None, :] nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0) nimage = np.concatenate([nimage], axis=0)
@ -335,8 +337,7 @@ def prepare_control_image(
resize_mode=resize_mode, resize_mode=resize_mode,
h=height, h=height,
w=width, w=width,
# device=torch.device('cpu') device=torch.device(device),
device=device,
) )
else: else:
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.") raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")