Add num_channels param to prepare_control_image(...).

Ryan Dick
2023-09-21 10:36:36 -04:00
parent 781fa206ea
commit a711b1daa6
5 changed files with 77 additions and 8 deletions
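For illustration only (not part of this commit): with the new parameter, a caller can ask prepare_control_image(...) for a control tensor with fewer channels than the source image. A minimal sketch of the new call, with illustrative sizes and a blank input image:

import torch
from PIL import Image
from invokeai.app.util.controlnet_utils import prepare_control_image

# Request a single-channel control tensor from an RGB input (num_channels is the new parameter).
control = prepare_control_image(
    image=Image.new("RGB", (512, 512)),
    width=512,
    height=512,
    num_channels=1,
    device="cpu",
    dtype=torch.float32,
    do_classifier_free_guidance=False,
)
print(control.shape)  # torch.Size([1, 1, 512, 512])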

@@ -9,7 +9,7 @@ import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.adapter import FullAdapterXL
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.attention_processor import (
AttnProcessor2_0,
LoRAAttnProcessor2_0,
@@ -493,6 +493,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
@@ -516,6 +517,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
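For illustration only (not part of this commit): the `num_channels` argument above is read from the adapter's diffusers config, so the control image is reduced to however many channels the adapter's first convolution expects (sketch-style adapters, for example, use a single channel). A minimal sketch of where that value lives, using a small hand-constructed T2IAdapter rather than a real checkpoint:

from diffusers.models.adapter import T2IAdapter

# Illustrative config only; real adapters are loaded from checkpoints and may use in_channels=1 or 3.
adapter = T2IAdapter(in_channels=1, channels=[32, 64], num_res_blocks=1, downscale_factor=8)
print(adapter.config.in_channels)      # -> 1; this is what gets forwarded as num_channels
print(adapter.total_downscale_factor)  # also used above to compute t2i_input_width/height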


@@ -265,20 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image,
width: int, # should be 8 * latent.shape[3]
height: int, # should be 8 * latent.shape[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
width: int,
height: int,
num_channels: int = 3,
device="cuda",
dtype=torch.float16,
do_classifier_free_guidance=True,
control_mode="balanced",
resize_mode="just_resize_simple",
):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
"""Pre-process images for ControlNets or T2I-Adapters.
Args:
image (Image): The PIL image to pre-process.
width (int): The target width in pixels.
height (int): The target height in pixels.
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.
control_mode (str, optional): Defaults to "balanced".
resize_mode (str, optional): Defaults to "just_resize_simple".
Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range.
Returns:
torch.Tensor: The pre-processed input tensor.
"""
if (
resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple"
@@ -313,6 +334,10 @@ def prepare_control_image(
else:
raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
if timage.shape[1] < num_channels or num_channels <= 0:
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
timage = timage[:, :num_channels, :, :]
timage = timage.to(device=device, dtype=dtype)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
if do_classifier_free_guidance and not cfg_injection:
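For illustration only (not part of this commit): the lines added above implement the channel truncation described in the docstring. At this point the image tensor is NCHW, and only the first `num_channels` channels are kept, failing loudly if the request cannot be satisfied. A standalone sketch of the same logic, using a hypothetical helper name:

import torch

def truncate_channels(timage: torch.Tensor, num_channels: int) -> torch.Tensor:
    # `timage` is NCHW; keep only the first `num_channels` channels (e.g. 1 for a grayscale control image).
    if timage.shape[1] < num_channels or num_channels <= 0:
        raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
    return timage[:, :num_channels, :, :]

rgb = torch.rand(1, 3, 256, 256)  # a 3-channel control image tensor
assert truncate_channels(rgb, 1).shape == (1, 1, 256, 256)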

tests/app/__init__.py (new, empty file)

New test file:

@@ -0,0 +1,42 @@
import numpy as np
import pytest
from PIL import Image

from invokeai.app.util.controlnet_utils import prepare_control_image


@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_prepare_control_image_num_channels(num_channels):
    """Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
    pil_image = Image.fromarray(np_image)

    torch_image = prepare_control_image(
        image=pil_image,
        width=256,
        height=256,
        num_channels=num_channels,
        device="cpu",
        do_classifier_free_guidance=False,
    )

    assert torch_image.shape == (1, num_channels, 256, 256)


@pytest.mark.parametrize("num_channels", [0, 4])
def test_prepare_control_image_num_channels_too_large(num_channels):
    """Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
    supported range.
    """
    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
    pil_image = Image.fromarray(np_image)

    with pytest.raises(ValueError):
        _ = prepare_control_image(
            image=pil_image,
            width=256,
            height=256,
            num_channels=num_channels,
            device="cpu",
            do_classifier_free_guidance=False,
        )