InvokeAI/invokeai/backend/image_util/util.py

from math import ceil, floor, sqrt
import cv2
import numpy as np
from PIL import Image


class InitImageResizer:
    """Simple class to create resized copies of an Image while preserving the aspect ratio."""

    def __init__(self, image: Image.Image):
        self.image = image
    def resize(self, width=None, height=None) -> Image.Image:
        """
        Return a copy of the image resized to fit within a box of
        width x height, preserving the aspect ratio. If neither width nor
        height is provided, a copy of the original image is returned. If
        only one is provided, the other is calculated from the aspect ratio.
        The box and the resized dimensions are floored to the nearest
        multiple of 64 so that the result can be passed to img2img().
        """
        im = self.image

        ar = im.width / float(im.height)

        # Infer missing values from aspect ratio
        if not (width or height):  # both missing
            width = im.width
            height = im.height
        elif not height:  # height missing
            height = int(width / ar)
        elif not width:  # width missing
            width = int(height * ar)

        w_scale = width / im.width
        h_scale = height / im.height
        scale = min(w_scale, h_scale)
        (rw, rh) = (int(scale * im.width), int(scale * im.height))

        # Floor everything to multiples of 64
        width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))

        # No resize necessary, but return a copy
        if im.width == width and im.height == height:
            return im.copy()

        # Otherwise resize the original image so that it fits inside the bounding box
        resized_image = self.image.resize((rw, rh), resample=Image.Resampling.LANCZOS)
        return resized_image
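The size arithmetic in `resize()` can be checked without PIL. The sketch below is a hypothetical pure-Python helper, `resized_size_64` (not part of InvokeAI), that mirrors the steps above: infer the missing side, scale uniformly to fit the box, floor to multiples of 64, and report the size the method would return.

```python
def resized_size_64(im_w, im_h, width=None, height=None):
    """Hypothetical helper: the (width, height) InitImageResizer.resize() would return."""
    ar = im_w / float(im_h)
    # Infer missing box dimensions from the aspect ratio
    if not (width or height):
        width, height = im_w, im_h
    elif not height:
        height = int(width / ar)
    elif not width:
        width = int(height * ar)

    # Scale uniformly so the image fits inside the box
    scale = min(width / im_w, height / im_h)
    rw, rh = int(scale * im_w), int(scale * im_h)

    # Floor the box and the resized size to multiples of 64
    width, height, rw, rh = (x - x % 64 for x in (width, height, rw, rh))

    # A box equal to the original size means resize() returns a plain copy
    if (im_w, im_h) == (width, height):
        return im_w, im_h
    return rw, rh


print(resized_size_64(600, 400))              # no box given: (576, 384)
print(resized_size_64(1024, 768, width=512))  # height inferred: (512, 384)
print(resized_size_64(2048, 1024, 512, 512))  # width is the limit: (512, 256)
```

Because each side is floored independently, the aspect ratio is only approximately preserved, and any resized side smaller than 64 floors to 0, so very small boxes are not meaningful inputs.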
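

def make_grid(image_list, rows=None, cols=None):
    """Paste a list of images into a single RGB grid, row by row.

    Every cell takes the size of the first image. If neither rows nor cols is
    given, the grid is made as close to square as possible; if only one is
    given, the other is computed from the number of images.
    """
    image_cnt = len(image_list)
    if rows is None and cols is None:
        rows = floor(sqrt(image_cnt))  # try to make it square
        cols = ceil(image_cnt / rows)
    elif rows is None:
        rows = ceil(image_cnt / cols)
    elif cols is None:
        cols = ceil(image_cnt / rows)
    width = image_list[0].width
    height = image_list[0].height

    grid_img = Image.new("RGB", (width * cols, height * rows))

    # Fill cells row by row; images beyond rows * cols are ignored
    for i, image in enumerate(image_list[: rows * cols]):
        r, c = divmod(i, cols)
        grid_img.paste(image, (c * width, r * height))
    return grid_img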
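To see where `make_grid` places each image, the default layout and paste offsets can be computed on their own. This is a sketch with a hypothetical helper, `grid_layout` (not an InvokeAI function), covering only the default case where neither `rows` nor `cols` is given.

```python
from math import ceil, floor, sqrt


def grid_layout(image_cnt, width, height):
    """Hypothetical helper: default grid shape and top-left paste offset per image."""
    rows = floor(sqrt(image_cnt))  # as close to square as possible
    cols = ceil(image_cnt / rows)
    # Images are placed row by row, left to right
    offsets = [((i % cols) * width, (i // cols) * height) for i in range(image_cnt)]
    return rows, cols, offsets


rows, cols, offsets = grid_layout(5, 64, 32)
print(rows, cols)  # 2 3
print(offsets)     # [(0, 0), (64, 0), (128, 0), (0, 32), (64, 32)]
```

With five 64x32 images the grid canvas is 192x64, and the sixth cell stays at the default black fill of `Image.new("RGB", ...)`.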


def pil_to_np(image: Image.Image) -> np.ndarray:
    """Converts a PIL image to a numpy array."""
    return np.array(image, dtype=np.uint8)


def np_to_pil(image: np.ndarray) -> Image.Image:
    """Converts a numpy array to a PIL image."""
    return Image.fromarray(image)


def pil_to_cv2(image: Image.Image) -> np.ndarray:
    """Converts an RGB PIL image to a BGR numpy array, the channel order used by OpenCV."""
    return cv2.cvtColor(np.array(image, dtype=np.uint8), cv2.COLOR_RGB2BGR)


def cv2_to_pil(image: np.ndarray) -> Image.Image:
    """Converts a BGR numpy array, as used by OpenCV, to an RGB PIL image."""
    return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
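For a 3-channel image, `COLOR_RGB2BGR` and `COLOR_BGR2RGB` simply reverse the channel axis. The NumPy-only sketch below (no OpenCV; `rgb_to_bgr` is a hypothetical name) shows the equivalent operation and why applying it twice returns the original pixels.

```python
import numpy as np


def rgb_to_bgr(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: swap RGB <-> BGR by reversing the last (channel) axis."""
    # [..., ::-1] is a reversed view; copy it into a contiguous array because
    # some OpenCV functions can reject arrays with negative strides.
    return np.ascontiguousarray(image[..., ::-1])


rgb = np.zeros((2, 2, 3), dtype=np.uint8)
rgb[0, 0] = (255, 0, 0)  # a pure red pixel in RGB order
bgr = rgb_to_bgr(rgb)
print(bgr[0, 0].tolist())  # [0, 0, 255]: red is now the last channel
```

Reversal is its own inverse, which is why one helper covers both directions; the module's functions use `cv2.cvtColor` instead, which also handles the dtype and layout conversion for OpenCV.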


def normalize_image_channel_count(image: np.ndarray) -> np.ndarray:
    """Normalizes an image to have 3 channels.

    If the image is 2D (height x width), it is treated as a 1-channel image.
    If the image has 1 channel, it will be duplicated 3 times.
    If the image has 3 channels, it is returned unchanged.
    If the image has 4 channels, the alpha channel will be used to blend the image with a white background.

    Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).

    Args:
        image: The input image, as a uint8 array.

    Returns:
        The normalized image.
    """
    assert image.dtype == np.uint8
    if image.ndim == 2:
        image = image[:, :, None]
    assert image.ndim == 3
    _height, _width, channels = image.shape
    assert channels == 1 or channels == 3 or channels == 4
    if channels == 3:
        return image
    if channels == 1:
        return np.concatenate([image, image, image], axis=2)
    if channels == 4:
        color = image[:, :, 0:3].astype(np.float32)
        alpha = image[:, :, 3:4].astype(np.float32) / 255.0
        normalized = color * alpha + 255.0 * (1.0 - alpha)
        normalized = normalized.clip(0, 255).astype(np.uint8)
        return normalized
    raise ValueError("Invalid number of channels.")
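The 4-channel branch is standard alpha compositing over white: `out = color * a + 255 * (1 - a)` with `a = alpha / 255`. The sketch below applies that formula to two pixels (`blend_over_white` is a hypothetical helper, not part of the module) to show the two extremes: a fully transparent pixel becomes white and a fully opaque one keeps its color.

```python
import numpy as np


def blend_over_white(rgba: np.ndarray) -> np.ndarray:
    """Hypothetical helper: composite an RGBA uint8 array onto a white background."""
    color = rgba[..., 0:3].astype(np.float32)
    alpha = rgba[..., 3:4].astype(np.float32) / 255.0  # trailing axis broadcasts over RGB
    blended = color * alpha + 255.0 * (1.0 - alpha)
    return blended.clip(0, 255).astype(np.uint8)


pixels = np.array(
    [
        [200, 100, 0, 0],    # fully transparent
        [200, 100, 0, 255],  # fully opaque
    ],
    dtype=np.uint8,
)
print(blend_over_white(pixels).tolist())  # [[255, 255, 255], [200, 100, 0]]
```

Since `astype(np.uint8)` truncates, blended values for intermediate alphas are rounded down rather than to the nearest integer.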


def resize_image_to_resolution(input_image: np.ndarray, resolution: int) -> np.ndarray:
    """Resizes an image so that its shorter side is approximately `resolution`.

    The image is scaled so that its shorter side equals `resolution`, then both sides
    are rounded to the nearest multiple of 64. Upscaling uses Lanczos interpolation;
    downscaling uses area interpolation.

    Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).

    Args:
        input_image: The input image.
        resolution: The target length of the shorter side, in pixels.

    Returns:
        The resized image.
    """
    h = float(input_image.shape[0])
    w = float(input_image.shape[1])
    scaling_factor = float(resolution) / min(h, w)
    h *= scaling_factor
    w *= scaling_factor
    h = int(np.round(h / 64.0)) * 64
    w = int(np.round(w / 64.0)) * 64
    # cv2.resize expects the target size as (width, height)
    if scaling_factor > 1:
        return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_LANCZOS4)
    else:
        return cv2.resize(input_image, (w, h), interpolation=cv2.INTER_AREA)
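The target size chosen by `resize_image_to_resolution` can be worked out by hand. The pure-Python sketch below repeats the arithmetic (`target_size` is a hypothetical name) and also reports which interpolation branch the function would take. Python's `round` and `np.round` both round halves to even, so the results should agree.

```python
def target_size(h: int, w: int, resolution: int):
    """Hypothetical helper: (height, width, upscaling?) as computed by resize_image_to_resolution."""
    scaling_factor = float(resolution) / min(h, w)
    # Scale, then round each side to the nearest multiple of 64
    new_h = int(round(h * scaling_factor / 64.0)) * 64
    new_w = int(round(w * scaling_factor / 64.0)) * 64
    # scaling_factor > 1 selects INTER_LANCZOS4, otherwise INTER_AREA
    return new_h, new_w, scaling_factor > 1


print(target_size(480, 640, 512))    # (512, 704, True)
print(target_size(1080, 1920, 512))  # (512, 896, False)
```

For a 640x480 input, the long side becomes 682.67 before rounding, which snaps to 704 (11 x 64); the helper returns (height, width), the reverse of the `(w, h)` order passed to `cv2.resize`.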


def non_maximum_suppression(image: np.ndarray, threshold: int, sigma: float):
    """
    Apply non-maximum suppression to an image.

    The image is blurred, then a pixel is kept only if it is at least as large as its
    two neighbours along at least one of four directions (horizontal, vertical, and the
    two diagonals). Kept pixels above the threshold are set to 255; all others are set to 0.

    This function is adapted from https://github.com/lllyasviel/ControlNet.

    Args:
        image: The input image.
        threshold: The threshold for the suppressed, blurred image. Kept pixels with values greater than this are set to 255.
        sigma: The standard deviation for the Gaussian blur applied to the image.

    Returns:
        The image after non-maximum suppression.
    """
    image = cv2.GaussianBlur(image.astype(np.float32), (0, 0), sigma)

    # 3x3 line kernels: horizontal, vertical, diagonal and anti-diagonal
    filter_1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
    filter_2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
    filter_3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
    filter_4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)

    y = np.zeros_like(image)
    for f in [filter_1, filter_2, filter_3, filter_4]:
        # Keep pixels that equal the maximum of their neighbourhood along this direction
        np.putmask(y, cv2.dilate(image, kernel=f) == image, image)
    z = np.zeros_like(y, dtype=np.uint8)
    z[y > threshold] = 255
    return z
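The directional test inside `non_maximum_suppression` can be reproduced with NumPy alone. In this sketch (`directional_max` and `nms_sketch` are hypothetical names), each 3x3 line kernel is written as a list of (row, column) offsets, and dilation is emulated by taking the maximum over shifted copies of the image. The Gaussian blur is left out, and padding with `-inf` is an assumption intended to keep out-of-bounds pixels from ever winning the maximum, as is assumed for `cv2.dilate`'s default border handling.

```python
import numpy as np

# (row, col) offsets covered by the four 3x3 line kernels in the original code
DIRECTIONS = [
    [(0, -1), (0, 0), (0, 1)],   # horizontal (filter_1)
    [(-1, 0), (0, 0), (1, 0)],   # vertical (filter_2)
    [(-1, -1), (0, 0), (1, 1)],  # diagonal (filter_3)
    [(-1, 1), (0, 0), (1, -1)],  # anti-diagonal (filter_4)
]


def directional_max(image: np.ndarray, offsets) -> np.ndarray:
    """Emulate dilation with a line kernel: max over the pixel and its two neighbours."""
    h, w = image.shape
    padded = np.pad(image, 1, mode="constant", constant_values=-np.inf)
    shifted = [padded[1 + dy : 1 + dy + h, 1 + dx : 1 + dx + w] for dy, dx in offsets]
    return np.max(shifted, axis=0)


def nms_sketch(image: np.ndarray, threshold: float) -> np.ndarray:
    """Hypothetical NMS without blur: keep pixels that are a maximum along any direction."""
    image = image.astype(np.float32)
    kept = np.zeros_like(image)
    for offsets in DIRECTIONS:
        np.putmask(kept, directional_max(image, offsets) == image, image)
    out = np.zeros(image.shape, dtype=np.uint8)
    out[kept > threshold] = 255
    return out


# A ramp that rises in every direction toward the bottom-right corner
ramp = np.array([[0, 10, 20], [20, 30, 40], [40, 50, 60]])
print(nms_sketch(ramp, threshold=5))
```

The centre pixel (30) is smaller than one of its two neighbours in every direction, so it is suppressed; most border pixels survive because the `-inf` padding makes them maxima along at least one direction.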


def safe_step(x: np.ndarray, step: int = 2) -> np.ndarray:
    """Apply the safe step operation to an array.

    This quantizes the array: each value is multiplied by ``step + 1``, truncated to an
    integer, and divided by ``step``. For inputs in [0, 1), the output takes one of the
    ``step + 1`` values 0, 1/step, ..., 1.

    Adapted from https://github.com/huggingface/controlnet_aux (Apache-2.0 license).

    Args:
        x: The input array.
        step: The step value.

    Returns:
        The array after the safe step operation.
    """
    y = x.astype(np.float32) * float(step + 1)
    y = y.astype(np.int32).astype(np.float32) / float(step)
    return y
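A worked example makes the quantization concrete. With the default `step=2`, values are multiplied by 3, truncated, and divided by 2, so inputs in [0, 1) land on 0, 0.5, or 1.0. The snippet repeats the two lines of `safe_step` inline so that it runs on its own.

```python
import numpy as np

x = np.array([0.0, 0.2, 0.4, 0.6, 0.9, 1.0])
step = 2

# Same two operations as safe_step(): scale by step + 1, truncate, divide by step
y = x.astype(np.float32) * float(step + 1)
y = y.astype(np.int32).astype(np.float32) / float(step)

print(y.tolist())  # [0.0, 0.0, 0.5, 0.5, 1.0, 1.5]
```

Note that an input of exactly 1.0 maps to 1.5, above the range covered by inputs in [0, 1).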