From a711b1daa6e74df402ef7ea54f386703e5543d35 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 21 Sep 2023 10:36:36 -0400
Subject: [PATCH] Add num_channels param to prepare_control_image(...).

---
 invokeai/app/invocations/latent.py      |  4 ++-
 invokeai/app/util/controlnet_utils.py   | 39 ++++++++++++++++++-----
 tests/app/__init__.py                   |  0
 tests/app/util/__init__.py              |  0
 tests/app/util/test_controlnet_utils.py | 42 +++++++++++++++++++++++++
 5 files changed, 77 insertions(+), 8 deletions(-)
 create mode 100644 tests/app/__init__.py
 create mode 100644 tests/app/util/__init__.py
 create mode 100644 tests/app/util/test_controlnet_utils.py

diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index 746490ed7a..722d135bd7 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -9,7 +9,7 @@ import torch
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.adapter import FullAdapterXL
+from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -493,6 +493,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
                 )

+            t2i_adapter_model: T2IAdapter
             with t2i_adapter_model_info as t2i_adapter_model:
                 total_downscale_factor = t2i_adapter_model.total_downscale_factor
                 if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
@@ -516,6 +517,7 @@
                 do_classifier_free_guidance=False,
                 width=t2i_input_width,
                 height=t2i_input_height,
+                num_channels=t2i_adapter_model.config.in_channels,
                 device=t2i_adapter_model.device,
                 dtype=t2i_adapter_model.dtype,
                 resize_mode=t2i_adapter_field.resize_mode,
diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py
index 9cda346e56..e6f34a4c44 100644
--- a/invokeai/app/util/controlnet_utils.py
+++ b/invokeai/app/util/controlnet_utils.py
@@ -265,20 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:


 def prepare_control_image(
-    # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
-    # but now should be able to assume that image is a single PIL.Image, which simplifies things
     image: Image,
-    width: int,  # should be 8 * latent.shape[3]
-    height: int,  # should be 8 * latent height[2]
-    # batch_size=1,  # currently no batching
-    # num_images_per_prompt=1,  # currently only single image
+    width: int,
+    height: int,
+    num_channels: int = 3,
     device="cuda",
     dtype=torch.float16,
     do_classifier_free_guidance=True,
     control_mode="balanced",
     resize_mode="just_resize_simple",
 ):
-    # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
+    """Pre-process images for ControlNets or T2I-Adapters.
+
+    Args:
+        image (Image): The PIL image to pre-process.
+        width (int): The target width in pixels.
+        height (int): The target height in pixels.
+        num_channels (int, optional): The target number of image channels. This is achieved by converting the input
+            image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting an
+            RGB image to a single-channel grayscale image. Raises a ValueError if this is not possible. Defaults to 3.
+        device (str, optional): The target device for the output image. Defaults to "cuda".
+        dtype (torch.dtype, optional): The dtype for the output image. Defaults to torch.float16.
+        do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
+            Defaults to True.
+        control_mode (str, optional): Defaults to "balanced".
+        resize_mode (str, optional): Defaults to "just_resize_simple".
+
+    Raises:
+        NotImplementedError: If resize_mode == "crop_resize_simple".
+        NotImplementedError: If resize_mode == "fill_resize_simple".
+        ValueError: If `resize_mode` is not recognized.
+        ValueError: If `num_channels` is out of range.
+
+    Returns:
+        torch.Tensor: The pre-processed input tensor.
+    """
     if (
         resize_mode == "just_resize_simple"
         or resize_mode == "crop_resize_simple"
@@ -313,6 +334,10 @@
     else:
         raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")

+    if num_channels < 1 or num_channels > timage.shape[1]:
+        raise ValueError(f"Cannot achieve num_channels={num_channels}: expected a value in [1, {timage.shape[1]}].")
+    timage = timage[:, :num_channels, :, :]
+
     timage = timage.to(device=device, dtype=dtype)
     cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
     if do_classifier_free_guidance and not cfg_injection:
diff --git a/tests/app/__init__.py b/tests/app/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/app/util/__init__.py b/tests/app/util/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/app/util/test_controlnet_utils.py b/tests/app/util/test_controlnet_utils.py
new file mode 100644
index 0000000000..21662cce8d
--- /dev/null
+++ b/tests/app/util/test_controlnet_utils.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+from PIL import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+
+
+@pytest.mark.parametrize("num_channels", [1, 2, 3])
+def test_prepare_control_image_num_channels(num_channels):
+    """Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
+    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
+    pil_image = Image.fromarray(np_image)
+
+    torch_image = prepare_control_image(
+        image=pil_image,
+        width=256,
+        height=256,
+        num_channels=num_channels,
+        device="cpu",
+        do_classifier_free_guidance=False,
+    )
+
+    assert torch_image.shape == (1, num_channels, 256, 256)
+
+
+@pytest.mark.parametrize("num_channels", [0, 4])
+def test_prepare_control_image_num_channels_out_of_range(num_channels):
+    """Test that prepare_control_image(...) raises a ValueError if the `num_channels` parameter is outside of the
+    supported range.
+    """
+    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
+    pil_image = Image.fromarray(np_image)
+
+    with pytest.raises(ValueError):
+        _ = prepare_control_image(
+            image=pil_image,
+            width=256,
+            height=256,
+            num_channels=num_channels,
+            device="cpu",
+            do_classifier_free_guidance=False,
+        )
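
Usage sketch (illustrative, not part of the patch): the primary use case for `num_channels` is matching the control image to an adapter that expects single-channel input, mirroring the call site in latent.py above where the value comes from `t2i_adapter_model.config.in_channels`. The gradient image, the 256x256 size, and the CPU/float32 settings below are assumptions chosen for the example.

    import numpy as np
    import torch
    from PIL import Image

    from invokeai.app.util.controlnet_utils import prepare_control_image

    # Build an illustrative RGB control image: a horizontal gradient.
    gradient = np.tile(np.arange(256, dtype=np.uint8), (256, 1))
    pil_image = Image.fromarray(np.stack([gradient] * 3, axis=-1))

    # num_channels=1 keeps only the first channel after the internal RGB
    # conversion, as a single-channel (in_channels=1) T2I-Adapter expects.
    control = prepare_control_image(
        image=pil_image,
        width=256,
        height=256,
        num_channels=1,
        device="cpu",
        dtype=torch.float32,
        do_classifier_free_guidance=False,
    )
    print(control.shape)  # torch.Size([1, 1, 256, 256])

With `do_classifier_free_guidance=True` (the default), the output would instead be repeated along the batch dimension, giving shape (2, 1, 256, 256).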