From 8379feeb8a9b3114b9a17c243a9116c64f3e2387 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Wed, 26 Jun 2024 20:39:29 -0400 Subject: [PATCH] Refactor TiledStableDiffusionRefineInvocation to more closely mirror TiledMultiDiffusionDenoiseLatents. The biggest improvement is in the handling of the ControlNets - global ControlNet info can now be passed in and it is tiled within the node. --- .../tiled_stable_diffusion_refine.py | 274 ++++++++---------- 1 file changed, 115 insertions(+), 159 deletions(-) diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py index c4075407a7..968cae9313 100644 --- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py +++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py @@ -2,14 +2,14 @@ from contextlib import ExitStack from typing import Iterator, Tuple import numpy as np -import numpy.typing as npt import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from PIL import Image from pydantic import field_validator -from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation +from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES +from invokeai.app.invocations.controlnet_image_processors import ControlField from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler from invokeai.app.invocations.fields import ( ConditioningField, @@ -17,22 +17,24 @@ from invokeai.app.invocations.fields import ( ImageField, Input, InputField, + LatentsField, UIType, ) from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation -from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField -from invokeai.app.invocations.noise import get_noise +from invokeai.app.invocations.model import UNetField, VAEField from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.tiled_multi_diffusion_denoise_latents import crop_controlnet_data from invokeai.app.services.shared.invocation_context import InvocationContext -from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor -from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending -from invokeai.backend.tiles.utils import Tile +from invokeai.backend.tiles.tiles import ( + calc_tiles_min_overlap, + merge_tiles_with_linear_blending, +) +from invokeai.backend.tiles.utils import TBLR, Tile from invokeai.backend.util.devices import TorchDevice -from invokeai.backend.util.hotfixes import ControlNetModel @invocation( @@ -40,6 +42,7 @@ from invokeai.backend.util.hotfixes import ControlNetModel title="Tiled Stable Diffusion Refine", tags=["upscale", "denoise"], category="latents", + classification=Classification.Beta, version="1.0.0", ) class TiledStableDiffusionRefineInvocation(BaseInvocation): @@ -55,13 +58,21 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): negative_conditioning: ConditioningField = InputField( description=FieldDescriptions.negative_cond, input=Input.Connection ) - # TODO(ryand): Add multiple-of validation. - tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.") - tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.") + noise: LatentsField = InputField( + description=FieldDescriptions.noise, + input=Input.Connection, + ) + tile_height: int = InputField( + default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space." + ) + tile_width: int = InputField( + default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space." + ) tile_overlap: int = InputField( - default=16, + default=32, + multiple_of=LATENT_SCALE_FACTOR, gt=0, - description="Target overlap between adjacent tiles (the last row/column may overlap more than this).", + description="Target overlap between adjacent tiles in image space.", ) steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps) cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") @@ -92,16 +103,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): vae_fp32: bool = InputField( default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE." ) - # HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly - # don't want to use the image field. Figure out how best to handle this. - # TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty - # much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support - # CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models, - # etc. - control_model: ModelIdentifierField = InputField( - description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel + control: ControlField | list[ControlField] | None = InputField( + default=None, + input=Input.Connection, ) - control_weight: float = InputField(default=0.6) @field_validator("cfg_scale") def ge_one(cls, v: list[float] | float) -> list[float] | float: @@ -115,90 +120,72 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v - @staticmethod - def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor: - """Crop the latent-space tensor to the area corresponding to the image-space tile. - The tile coordinates must be divisible by the LATENT_SCALE_FACTOR. - """ - for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]: - if coord % LATENT_SCALE_FACTOR != 0: - raise ValueError( - f"The tile coordinates must all be divisible by the latent scale factor" - f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}." - ) - assert latents.dim() == 4 # We expect: (batch_size, channels, height, width). - - top = image_tile.coords.top // LATENT_SCALE_FACTOR - left = image_tile.coords.left // LATENT_SCALE_FACTOR - bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR - right = image_tile.coords.right // LATENT_SCALE_FACTOR - return latents[..., top:bottom, left:right] - - def run_controlnet( - self, - image: Image.Image, - controlnet_model: ControlNetModel, - weight: float, - do_classifier_free_guidance: bool, - width: int, - height: int, - device: torch.device, - dtype: torch.dtype, - control_mode: CONTROLNET_MODE_VALUES = "balanced", - resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple", - ) -> ControlNetData: - control_image = prepare_control_image( - image=image, - do_classifier_free_guidance=do_classifier_free_guidance, - width=width, - height=height, - device=device, - dtype=dtype, - control_mode=control_mode, - resize_mode=resize_mode, - ) - return ControlNetData( - model=controlnet_model, - image_tensor=control_image, - weight=weight, - begin_step_percent=0.0, - end_step_percent=1.0, - control_mode=control_mode, - # Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to - # ControlNetData in case needed in the future. - resize_mode=resize_mode, + def _scale_tile(self, tile: Tile, scale: int) -> Tile: + """Scale the tile by the given factor.""" + return Tile( + coords=TBLR( + top=tile.coords.top * scale, + bottom=tile.coords.bottom * scale, + left=tile.coords.left * scale, + right=tile.coords.right * scale, + ), + overlap=TBLR( + top=tile.overlap.top * scale, + bottom=tile.overlap.bottom * scale, + left=tile.overlap.left * scale, + right=tile.overlap.right * scale, + ), ) @torch.no_grad() def invoke(self, context: InvocationContext) -> ImageOutput: - # TODO(ryand): Expose the seed parameter. - seed = 0 + # Convert tile image-space dimensions to latent-space dimensions. + latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR + latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR + latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR # Load the input image. input_image = context.images.get_pil(self.image.image_name) - - # Calculate the tile locations to cover the image. - # We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This - # facilitates conversions between image space and latent space. - # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.) - tiles = calc_tiles_with_overlap( - image_height=input_image.height, - image_width=input_image.width, - tile_height=self.tile_height, - tile_width=self.tile_width, - overlap=self.tile_overlap, - ) - # Convert the input image to a torch.Tensor. input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR) input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension. # Validate our assumptions about the shape of input_image_torch. - assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width). - assert input_image_torch.shape[:2] == (1, 3) + batch_size, channels, image_height, image_width = input_image_torch.shape + assert batch_size == 1 + assert channels == 3 + + # Load the noise tensor. + noise = context.tensors.load(self.noise.latents_name) + if list(noise.shape) != [ + batch_size, + 4, + image_height // LATENT_SCALE_FACTOR, + image_width // LATENT_SCALE_FACTOR, + ]: + raise ValueError( + f"Incompatible noise and image dimensions. Image shape: {input_image_torch.shape}. " + f"Noise shape: {noise.shape}. Expected noise shape: [1, 1, " + f"{image_height // LATENT_SCALE_FACTOR}, {image_width // LATENT_SCALE_FACTOR}]. " + ) + latent_height, latent_width = noise.shape[2:] + + # Extract the seed from the noise field. + assert self.noise.seed is not None + seed = self.noise.seed or 0 + + # Calculate the tile locations in both latent space and image space. + latent_space_tiles = calc_tiles_min_overlap( + image_height=latent_height, + image_width=latent_width, + tile_height=latent_tile_height, + tile_width=latent_tile_width, + min_overlap=latent_tile_overlap, + ) + image_space_tiles = [self._scale_tile(tile, LATENT_SCALE_FACTOR) for tile in latent_space_tiles] # Split the input image into tiles in torch.Tensor format. image_tiles_torch: list[torch.Tensor] = [] - for tile in tiles: + for tile in image_space_tiles: image_tile = input_image_torch[ :, :, @@ -207,22 +194,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): ] image_tiles_torch.append(image_tile) - # Split the input image into tiles in numpy format. - # TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work - # with torch.Tensor tiles. - input_image_np = np.array(input_image) - image_tiles_np: list[npt.NDArray[np.uint8]] = [] - for tile in tiles: - image_tile_np = input_image_np[ - tile.coords.top : tile.coords.bottom, - tile.coords.left : tile.coords.right, - :, - ] - image_tiles_np.append(image_tile_np) - # VAE-encode each image tile independently. - # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What - # about for decoding? vae_info = context.models.load(self.vae.vae) latent_tiles: list[torch.Tensor] = [] for image_tile_torch in image_tiles_torch: @@ -232,23 +204,16 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): ) ) - # Generate noise with dimensions corresponding to the full image in latent space. - # It is important that the noise tensor is generated at the full image dimension and then tiled, rather than - # generating for each tile independently. This ensures that overlapping regions between tiles use the same - # noise. - assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0 - assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0 - global_noise = get_noise( - width=input_image_torch.shape[3], - height=input_image_torch.shape[2], - device=TorchDevice.choose_torch_device(), - seed=seed, - downsampling_factor=LATENT_SCALE_FACTOR, - use_cpu=True, - ) - # Crop the global noise into tiles. - noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles] + noise_tiles: list[torch.Tensor] = [] + for tile in latent_space_tiles: + noise_tile = noise[ + :, + :, + tile.coords.top : tile.coords.bottom, + tile.coords.left : tile.coords.right, + ] + noise_tiles.append(noise_tile) # Prepare an iterator that yields the UNet's LoRA models and their weights. def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: @@ -273,53 +238,42 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler) # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles. - # Assume that all tiles have the same shape. - _, _, latent_height, latent_width = latent_tiles[0].shape conditioning_data = DenoiseLatentsInvocation.get_conditioning_data( context=context, positive_conditioning_field=self.positive_conditioning, negative_conditioning_field=self.negative_conditioning, unet=unet, - latent_height=latent_height, - latent_width=latent_width, + latent_height=latent_tile_height, + latent_width=latent_tile_width, cfg_scale=self.cfg_scale, steps=self.steps, cfg_rescale_multiplier=self.cfg_rescale_multiplier, ) - # Load the ControlNet model. - # TODO(ryand): Support multiple ControlNet models. - controlnet_model = exit_stack.enter_context(context.models.load(self.control_model)) - assert isinstance(controlnet_model, ControlNetModel) + controlnet_data = DenoiseLatentsInvocation.prep_control_data( + context=context, + control_input=self.control, + # NOTE: We use the shape of the global noise tensor here, because this is a global ControlNet. We tile + # it later. + latents_shape=list(noise.shape), + # do_classifier_free_guidance=(self.cfg_scale >= 1.0)) + do_classifier_free_guidance=True, + exit_stack=exit_stack, + ) + + # Split the controlnet_data into tiles. + # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile. + controlnet_data_tiles: list[list[ControlNetData]] = [] + for tile in latent_space_tiles: + tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []] + controlnet_data_tiles.append(tile_controlnet_data) # Denoise (i.e. "refine") each tile independently. - for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True): + for latent_tile, noise_tile, controlnet_data_tile in zip( + latent_tiles, noise_tiles, controlnet_data_tiles, strict=True + ): assert latent_tile.shape == noise_tile.shape - # Prepare a PIL Image for ControlNet processing. - # TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of - # the tiles. Ideally, the ControlNet code should be able to work with Tensors. - image_tile_pil = Image.fromarray(image_tile_np) - - # Run the ControlNet on the image tile. - height, width, _ = image_tile_np.shape - # The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we - # validate this assumption here. - assert height % LATENT_SCALE_FACTOR == 0 - assert width % LATENT_SCALE_FACTOR == 0 - controlnet_data = self.run_controlnet( - image=image_tile_pil, - controlnet_model=controlnet_model, - weight=self.control_weight, - do_classifier_free_guidance=True, - width=width, - height=height, - device=controlnet_model.device, - dtype=controlnet_model.dtype, - control_mode="balanced", - resize_mode="just_resize_simple", - ) - timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler( scheduler, device=unet.device, @@ -342,7 +296,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): masked_latents=None, scheduler_step_kwargs=scheduler_step_kwargs, conditioning_data=conditioning_data, - control_data=[controlnet_data], + control_data=controlnet_data_tile, ip_adapter_data=None, t2i_adapter_data=None, callback=lambda x: None, @@ -368,9 +322,11 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation): # Merge the refined image tiles back into a single image. refined_image_tiles_np = [np.array(t) for t in refined_image_tiles] merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8) - # TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter? merge_tiles_with_linear_blending( - dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap + dst_image=merged_image_np, + tiles=image_space_tiles, + tile_images=refined_image_tiles_np, + blend_amount=self.tile_overlap, ) # Save the refined image and return its reference.