From 85db33bc7ed78f232c70dcbf8aa0d5fc04417662 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Mon, 10 Jun 2024 10:52:14 -0400
Subject: [PATCH] Add naive ControlNet support to TiledStableDiffusionRefineInvocation

---
 .../tiled_stable_diffusion_refine.py  | 120 +++++++++++++++---
 invokeai/app/util/controlnet_utils.py |   4 +-
 2 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py
index 5313bc1c2c..a42d531241 100644
--- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py
+++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py
@@ -1,4 +1,7 @@
+from contextlib import ExitStack
+
 import numpy as np
+import numpy.typing as npt
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from PIL import Image
@@ -17,14 +20,16 @@ from invokeai.app.invocations.fields import (
 )
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
-from invokeai.app.invocations.model import UNetField, VAEField
+from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
 from invokeai.app.invocations.noise import get_noise
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
+from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
+from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
 from invokeai.backend.tiles.tiles import calc_tiles_min_overlap, merge_tiles_with_linear_blending
 from invokeai.backend.tiles.utils import Tile
 from invokeai.backend.util.devices import TorchDevice
+from invokeai.backend.util.hotfixes import ControlNetModel
 
 
 @invocation(
@@ -66,10 +71,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
         input=Input.Connection,
         title="UNet",
     )
-    # control: Optional[Union[ControlField, list[ControlField]]] = InputField(
-    #     default=None,
-    #     input=Input.Connection,
-    # )
     cfg_rescale_multiplier: float = InputField(
         title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
     )
@@ -80,6 +81,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
     vae_fp32: bool = InputField(
         default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
     )
+    # HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But we awkwardly
+    # don't want to use the image field. Figure out how best to handle this.
+    # TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
+    # much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
+    # CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
+    # etc.
+    control_model: ModelIdentifierField = InputField(
+        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
+    )
 
     @field_validator("cfg_scale")
     def ge_one(cls, v: list[float] | float) -> list[float] | float:
@@ -112,6 +122,41 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
         right = image_tile.coords.right // LATENT_SCALE_FACTOR
         return latents[..., top:bottom, left:right]
 
+    def run_controlnet(
+        self,
+        image: Image.Image,
+        controlnet_model: ControlNetModel,
+        weight: float,
+        do_classifier_free_guidance: bool,
+        width: int,
+        height: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        control_mode: CONTROLNET_MODE_VALUES = "balanced",
+        resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
+    ) -> ControlNetData:
+        control_image = prepare_control_image(
+            image=image,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            width=width,
+            height=height,
+            device=device,
+            dtype=dtype,
+            control_mode=control_mode,
+            resize_mode=resize_mode,
+        )
+        return ControlNetData(
+            model=controlnet_model,
+            image_tensor=control_image,
+            weight=weight,
+            begin_step_percent=0.0,
+            end_step_percent=1.0,
+            control_mode=control_mode,
+            # Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
+            # ControlNetData in case needed in the future.
+            resize_mode=resize_mode,
+        )
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
         # TODO(ryand): Expose the seed parameter.
@@ -119,8 +164,6 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
 
         # Load the input image.
         input_image = context.images.get_pil(self.image.image_name)
-        input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
-        input_image_torch = input_image_torch.unsqueeze(0)  # Add a batch dimension.
 
         # Calculate the tile locations to cover the image.
         # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
@@ -132,12 +175,15 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
             min_overlap=128,
         )
 
+        # Convert the input image to a torch.Tensor.
+        input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
+        input_image_torch = input_image_torch.unsqueeze(0)  # Add a batch dimension.
         # Validate our assumptions about the shape of input_image_torch.
         assert input_image_torch.dim() == 4  # We expect: (batch_size, channels, height, width).
         assert input_image_torch.shape[:2] == (1, 3)
 
-        # Split the input image into tiles.
-        image_tiles: list[torch.Tensor] = []
+        # Split the input image into tiles in torch.Tensor format.
+        image_tiles_torch: list[torch.Tensor] = []
         for tile in tiles:
             image_tile = input_image_torch[
                 :,
@@ -145,17 +191,30 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                 tile.coords.top : tile.coords.bottom,
                 tile.coords.left : tile.coords.right,
             ]
-            image_tiles.append(image_tile)
+            image_tiles_torch.append(image_tile)
+
+        # Split the input image into tiles in numpy format.
+        # TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
+        # with torch.Tensor tiles.
+        input_image_np = np.array(input_image)
+        image_tiles_np: list[npt.NDArray[np.uint8]] = []
+        for tile in tiles:
+            image_tile_np = input_image_np[
+                tile.coords.top : tile.coords.bottom,
+                tile.coords.left : tile.coords.right,
+                :,
+            ]
+            image_tiles_np.append(image_tile_np)
 
         # VAE-encode each image tile independently.
         # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
         # about for decoding?
         vae_info = context.models.load(self.vae.vae)
         latent_tiles: list[torch.Tensor] = []
-        for image_tile in image_tiles:
+        for image_tile_torch in image_tiles_torch:
             latent_tiles.append(
                 ImageToLatentsInvocation.vae_encode(
-                    vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile
+                    vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
                 )
             )
 
@@ -181,7 +240,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
 
         unet_info = context.models.load(self.unet.unet)
         refined_latent_tiles: list[torch.Tensor] = []
-        with unet_info as unet:
+        with ExitStack() as exit_stack, unet_info as unet:
             assert isinstance(unet, UNet2DConditionModel)
             scheduler = get_scheduler(
                 context=context,
@@ -206,10 +265,39 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                 cfg_rescale_multiplier=self.cfg_rescale_multiplier,
             )
 
+            # Load the ControlNet model.
+            # TODO(ryand): Support multiple ControlNet models.
+            controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
+            assert isinstance(controlnet_model, ControlNetModel)
+
             # Denoise (i.e. "refine") each tile independently.
-            for latent_tile, noise_tile in zip(latent_tiles, noise_tiles, strict=True):
+            for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
                 assert latent_tile.shape == noise_tile.shape
 
+                # Prepare a PIL Image for ControlNet processing.
+                # TODO(ryand): It is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
+                # the tiles. Ideally, the ControlNet code should be able to work with Tensors.
+                image_tile_pil = Image.fromarray(image_tile_np)
+
+                # Run the ControlNet on the image tile.
+                height, width, _ = image_tile_np.shape
+                # The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
+                # validate this assumption here.
+                assert height % LATENT_SCALE_FACTOR == 0
+                assert width % LATENT_SCALE_FACTOR == 0
+                controlnet_data = self.run_controlnet(
+                    image=image_tile_pil,
+                    controlnet_model=controlnet_model,
+                    weight=1.0,
+                    do_classifier_free_guidance=True,
+                    width=width,
+                    height=height,
+                    device=controlnet_model.device,
+                    dtype=controlnet_model.dtype,
+                    control_mode="balanced",
+                    resize_mode="just_resize_simple",
+                )
+
                 num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
                     DenoiseLatentsInvocation.init_scheduler(
                         scheduler,
@@ -236,7 +324,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                     num_inference_steps=num_inference_steps,
                     scheduler_step_kwargs=scheduler_step_kwargs,
                     conditioning_data=conditioning_data,
-                    control_data=None,
+                    control_data=[controlnet_data],
                     ip_adapter_data=None,
                     t2i_adapter_data=None,
                     callback=lambda x: None,
diff --git a/invokeai/app/util/controlnet_utils.py b/invokeai/app/util/controlnet_utils.py
index fde8d52ee6..f92823a27f 100644
--- a/invokeai/app/util/controlnet_utils.py
+++ b/invokeai/app/util/controlnet_utils.py
@@ -289,7 +289,7 @@ def prepare_control_image(
     width: int,
     height: int,
     num_channels: int = 3,
-    device: str = "cuda",
+    device: str | torch.device = "cuda",
     dtype: torch.dtype = torch.float16,
     control_mode: CONTROLNET_MODE_VALUES = "balanced",
     resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
@@ -304,7 +304,7 @@ def prepare_control_image(
         num_channels (int, optional): The target number of image channels. This is achieved by converting the input
             image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
             RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
-        device (str, optional): The target device for the output image. Defaults to "cuda".
+        device (str | torch.device, optional): The target device for the output image. Defaults to "cuda".
         dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
         do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
             Defaults to True.