Add num_channels param to prepare_control_image(...).
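In brief: prepare_control_image(...) in invokeai.app.util.controlnet_utils gains a num_channels parameter, and the T2I-Adapter path in DenoiseLatentsInvocation now forwards the adapter's config.in_channels so the control tensor matches what the adapter expects. A minimal usage sketch of the new parameter (illustrative only, not part of the diff; the input file name and the 512x512 target size are hypothetical):

    from PIL import Image

    from invokeai.app.util.controlnet_utils import prepare_control_image

    pil_image = Image.open("sketch.png")  # hypothetical input; any RGB PIL image works

    control = prepare_control_image(
        image=pil_image,
        width=512,
        height=512,
        num_channels=1,  # e.g. a sketch-conditioned T2I-Adapter expects 1 channel
        device="cpu",
        do_classifier_free_guidance=False,
    )
    assert control.shape == (1, 1, 512, 512)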
@@ -9,7 +9,7 @@ import torch
 import torchvision.transforms as T
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import UNet2DConditionModel
-from diffusers.models.adapter import FullAdapterXL
+from diffusers.models.adapter import FullAdapterXL, T2IAdapter
 from diffusers.models.attention_processor import (
     AttnProcessor2_0,
     LoRAAttnProcessor2_0,
@@ -493,6 +493,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
                 )

+            t2i_adapter_model: T2IAdapter
             with t2i_adapter_model_info as t2i_adapter_model:
                 total_downscale_factor = t2i_adapter_model.total_downscale_factor
                 if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
@@ -516,6 +517,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 do_classifier_free_guidance=False,
                 width=t2i_input_width,
                 height=t2i_input_height,
+                num_channels=t2i_adapter_model.config.in_channels,
                 device=t2i_adapter_model.device,
                 dtype=t2i_adapter_model.dtype,
                 resize_mode=t2i_adapter_field.resize_mode,
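For context on the new argument above: a diffusers T2IAdapter records the channel count of its conditioning input in its model config, so the call site can forward it unchanged. A hedged sketch (the checkpoint id is an example, not taken from this commit, and loading it requires a download):

    from diffusers.models.adapter import T2IAdapter

    # Example checkpoint; sketch-conditioned SDXL adapters take single-channel input.
    adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-sketch-sdxl-1.0")
    print(adapter.config.in_channels)  # 1 for a sketch adapter; RGB-style adapters use 3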
invokeai/app/util/controlnet_utils.py

@@ -265,20 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:


 def prepare_control_image(
     # image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
     # but now should be able to assume that image is a single PIL.Image, which simplifies things
     image: Image,
-    width: int,  # should be 8 * latent.shape[3]
-    height: int,  # should be 8 * latent.shape[2]
-    # batch_size=1,  # currently no batching
-    # num_images_per_prompt=1,  # currently only single image
+    width: int,
+    height: int,
+    num_channels: int = 3,
     device="cuda",
     dtype=torch.float16,
     do_classifier_free_guidance=True,
     control_mode="balanced",
     resize_mode="just_resize_simple",
 ):
-    # FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out
+    """Pre-process images for ControlNets or T2I-Adapters.
+
+    Args:
+        image (Image): The PIL image to pre-process.
+        width (int): The target width in pixels.
+        height (int): The target height in pixels.
+        num_channels (int, optional): The target number of image channels. This is achieved by converting the input
+            image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting an
+            RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
+        device (str, optional): The target device for the output image. Defaults to "cuda".
+        dtype (torch.dtype, optional): The dtype of the output image. Defaults to torch.float16.
+        do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
+            Defaults to True.
+        control_mode (str, optional): Defaults to "balanced".
+        resize_mode (str, optional): Defaults to "just_resize_simple".
+
+    Raises:
+        NotImplementedError: If resize_mode == "crop_resize_simple".
+        NotImplementedError: If resize_mode == "fill_resize_simple".
+        ValueError: If `resize_mode` is not recognized.
+        ValueError: If `num_channels` is out of range.
+
+    Returns:
+        torch.Tensor: The pre-processed input tensor.
+    """
     if (
         resize_mode == "just_resize_simple"
         or resize_mode == "crop_resize_simple"
@@ -313,6 +334,10 @@ def prepare_control_image(
     else:
         raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")

+    if timage.shape[1] < num_channels or num_channels <= 0:
+        raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
+    timage = timage[:, :num_channels, :, :]
+
     timage = timage.to(device=device, dtype=dtype)
     cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"
     if do_classifier_free_guidance and not cfg_injection:
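As the new docstring notes, the truncation is naive: a 1-channel request keeps only the first (red) channel rather than computing a luminance blend. A standalone sketch of the same guard-and-slice logic, mirroring the hunk above:

    import torch

    def truncate_channels(timage: torch.Tensor, num_channels: int) -> torch.Tensor:
        # Reject unsatisfiable requests, then slice off the leading channels.
        if timage.shape[1] < num_channels or num_channels <= 0:
            raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
        return timage[:, :num_channels, :, :]

    rgb = torch.rand(1, 3, 256, 256)  # (N, C, H, W) control tensor
    assert truncate_channels(rgb, 1).shape == (1, 1, 256, 256)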
New files:
  0   tests/app/__init__.py
  0   tests/app/util/__init__.py
  42  tests/app/util/test_controlnet_utils.py
tests/app/util/test_controlnet_utils.py

@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+from PIL import Image
+
+from invokeai.app.util.controlnet_utils import prepare_control_image
+
+
+@pytest.mark.parametrize("num_channels", [1, 2, 3])
+def test_prepare_control_image_num_channels(num_channels):
+    """Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
+    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
+    pil_image = Image.fromarray(np_image)
+
+    torch_image = prepare_control_image(
+        image=pil_image,
+        width=256,
+        height=256,
+        num_channels=num_channels,
+        device="cpu",
+        do_classifier_free_guidance=False,
+    )
+
+    assert torch_image.shape == (1, num_channels, 256, 256)
+
+
+@pytest.mark.parametrize("num_channels", [0, 4])
+def test_prepare_control_image_num_channels_too_large(num_channels):
+    """Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
+    supported range.
+    """
+    np_image = np.zeros((256, 256, 3), dtype=np.uint8)
+    pil_image = Image.fromarray(np_image)
+
+    with pytest.raises(ValueError):
+        _ = prepare_control_image(
+            image=pil_image,
+            width=256,
+            height=256,
+            num_channels=num_channels,
+            device="cpu",
+            do_classifier_free_guidance=False,
+        )
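These tests exercise only the CPU path and need no model downloads; assuming pytest is installed, an invocation along the lines of `pytest tests/app/util/test_controlnet_utils.py` from the repository root should run them.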