diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 0c56b336b8..ae86c164e1 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -2,37 +2,36 @@ from contextlib import ExitStack
 from typing import Iterator, Tuple
 
 import numpy as np
-import numpy.typing as npt
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from PIL import Image
 from pydantic import field_validator
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
-from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
 from invokeai.app.invocations.fields import (
     ConditioningField,
     FieldDescriptions,
-    ImageField,
     Input,
     InputField,
+    LatentsField,
     UIType,
 )
-from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
-from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
+from invokeai.app.invocations.model import UNetField
 from invokeai.app.invocations.noise import get_noise
 from invokeai.app.invocations.primitives import ImageOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
-from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
-from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
-from invokeai.backend.tiles.utils import Tile
+from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
+from invokeai.backend.tiles.tiles import (
+    calc_tiles_min_overlap,
+    merge_tiles_with_linear_blending,
+)
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.hotfixes import ControlNetModel
 
 
 @invocation(
@@ -40,14 +39,19 @@ from invokeai.backend.util.hotfixes import ControlNetModel
     title="Tiled Stable Diffusion Refine",
     tags=["upscale", "denoise"],
     category="latents",
+    # TODO(ryand): Reset to 1.0.0 right before release.
     version="1.0.0",
 )
-class TiledStableDiffusionRefineInvocation(BaseInvocation):
-    """A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
-    refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
-    """
+class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
+    """Tiled Multi-Diffusion denoising.
 
-    image: ImageField = InputField(description="Image to be refined.")
+    This node handles automatically tiling the input image. Future iterations of
+    this node should allow the user to specify custom regions with different parameters for each region to harness the
+    full power of Multi-Diffusion.
+
+    This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
+    T2I-Adapter, masking, etc.).
+    """
 
     positive_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection
@@ -55,16 +59,29 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
     negative_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.negative_cond, input=Input.Connection
     )
+    noise: LatentsField | None = InputField(
+        default=None,
+        description=FieldDescriptions.noise,
+        input=Input.Connection,
+    )
+    latents: LatentsField | None = InputField(
+        default=None,
+        description=FieldDescriptions.latents,
+        input=Input.Connection,
+    )
     # TODO(ryand): Add multiple-of validation.
-    tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
-    tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
-    tile_overlap: int = InputField(
+    # TODO(ryand): Smaller defaults might make more sense.
+    tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
+    tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
+    tile_min_overlap: int = InputField(
         default=16,
         gt=0,
-        description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
+        description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
+        "this to evenly cover the entire image.",
     )
     steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
     cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
+    # TODO(ryand): The default here should probably be 0.0.
     denoising_start: float = InputField(
         default=0.65,
         ge=0,
@@ -85,23 +102,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
     cfg_rescale_multiplier: float = InputField(
         title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
     )
-    vae: VAEField = InputField(
-        description=FieldDescriptions.vae,
+    control: ControlField | list[ControlField] | None = InputField(
+        default=None,
         input=Input.Connection,
     )
-    vae_fp32: bool = InputField(
-        default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
-    )
-    # HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
-    # don't want to use the image field. Figure out how best to handle this.
-    # TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
-    # much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
-    # CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
-    # etc.
-    control_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
-    )
-    control_weight: float = InputField(default=0.6)
 
     @field_validator("cfg_scale")
     def ge_one(cls, v: list[float] | float) -> list[float] | float:
@@ -115,140 +119,44 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                 raise ValueError("cfg_scale must be greater than 1")
         return v
 
-    @staticmethod
-    def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
-        """Crop the latent-space tensor to the area corresponding to the image-space tile.
-        The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
-        """
-        for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
-            if coord % LATENT_SCALE_FACTOR != 0:
-                raise ValueError(
-                    f"The tile coordinates must all be divisible by the latent scale factor"
-                    f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
-                )
-        assert latents.dim() == 4  # We expect: (batch_size, channels, height, width).
-
-        top = image_tile.coords.top // LATENT_SCALE_FACTOR
-        left = image_tile.coords.left // LATENT_SCALE_FACTOR
-        bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
-        right = image_tile.coords.right // LATENT_SCALE_FACTOR
-        return latents[..., top:bottom, left:right]
-
-    def run_controlnet(
-        self,
-        image: Image.Image,
-        controlnet_model: ControlNetModel,
-        weight: float,
-        do_classifier_free_guidance: bool,
-        width: int,
-        height: int,
-        device: torch.device,
-        dtype: torch.dtype,
-        control_mode: CONTROLNET_MODE_VALUES = "balanced",
-        resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
-    ) -> ControlNetData:
-        control_image = prepare_control_image(
-            image=image,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            width=width,
-            height=height,
-            device=device,
-            dtype=dtype,
-            control_mode=control_mode,
-            resize_mode=resize_mode,
-        )
-        return ControlNetData(
-            model=controlnet_model,
-            image_tensor=control_image,
-            weight=weight,
-            begin_step_percent=0.0,
-            end_step_percent=1.0,
-            control_mode=control_mode,
-            # Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
-            # ControlNetData in case needed in the future.
-            resize_mode=resize_mode,
-        )
-
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # TODO(ryand): Expose the seed parameter.
-        seed = 0
+        seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
+        _, _, latent_height, latent_width = latents.shape
 
-        # Load the input image.
-        input_image = context.images.get_pil(self.image.image_name)
-
-        # Calculate the tile locations to cover the image.
-        # We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
-        # facilitates conversions between image space and latent space.
-        # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
-        tiles = calc_tiles_with_overlap(
-            image_height=input_image.height,
-            image_width=input_image.width,
-            tile_height=self.tile_height,
-            tile_width=self.tile_width,
-            overlap=self.tile_overlap,
-        )
-
-        # Convert the input image to a torch.Tensor.
-        input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
-        input_image_torch = input_image_torch.unsqueeze(0)  # Add a batch dimension.
-        # Validate our assumptions about the shape of input_image_torch.
-        assert input_image_torch.dim() == 4  # We expect: (batch_size, channels, height, width).
-        assert input_image_torch.shape[:2] == (1, 3)
-
-        # Split the input image into tiles in torch.Tensor format.
-        image_tiles_torch: list[torch.Tensor] = []
-        for tile in tiles:
-            image_tile = input_image_torch[
-                :,
-                :,
-                tile.coords.top : tile.coords.bottom,
-                tile.coords.left : tile.coords.right,
-            ]
-            image_tiles_torch.append(image_tile)
-
-        # Split the input image into tiles in numpy format.
-        # TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
-        # with torch.Tensor tiles.
-        input_image_np = np.array(input_image)
-        image_tiles_np: list[npt.NDArray[np.uint8]] = []
-        for tile in tiles:
-            image_tile_np = input_image_np[
-                tile.coords.top : tile.coords.bottom,
-                tile.coords.left : tile.coords.right,
-                :,
-            ]
-            image_tiles_np.append(image_tile_np)
-
-        # VAE-encode each image tile independently.
-        # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
-        # about for decoding?
-        vae_info = context.models.load(self.vae.vae)
-        latent_tiles: list[torch.Tensor] = []
-        for image_tile_torch in image_tiles_torch:
-            latent_tiles.append(
-                ImageToLatentsInvocation.vae_encode(
-                    vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
-                )
+        # If noise is None, populate it here.
+        # TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
+        # that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
+        # particular case, we want to make sure that the noise is generated globally rather than per-tile so that
+        # overlapping tile regions use the same noise.
+        if noise is None:
+            noise = get_noise(
+                width=latent_width * LATENT_SCALE_FACTOR,
+                height=latent_height * LATENT_SCALE_FACTOR,
+                device=TorchDevice.choose_torch_device(),
+                seed=seed,
+                downsampling_factor=LATENT_SCALE_FACTOR,
+                use_cpu=True,
            )
 
-        # Generate noise with dimensions corresponding to the full image in latent space.
-        # It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
-        # generating for each tile independently. This ensures that overlapping regions between tiles use the same
-        # noise.
-        assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
-        assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
-        global_noise = get_noise(
-            width=input_image_torch.shape[3],
-            height=input_image_torch.shape[2],
-            device=TorchDevice.choose_torch_device(),
-            seed=seed,
-            downsampling_factor=LATENT_SCALE_FACTOR,
-            use_cpu=True,
+        # Calculate the tile locations to cover the latent-space image.
+        # TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
+        tiles = calc_tiles_min_overlap(
+            image_height=latent_height,
+            image_width=latent_width,
+            tile_height=self.tile_height,
+            tile_width=self.tile_width,
+            min_overlap=self.tile_min_overlap,
         )
 
-        # Crop the global noise into tiles.
-        noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
+        # Split the noise and latents into tiles.
+        noise_tiles: list[torch.Tensor] = []
+        latent_tiles: list[torch.Tensor] = []
+        for tile in tiles:
+            noise_tile = noise[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
+            latent_tile = latents[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
+            noise_tiles.append(noise_tile)
+            latent_tiles.append(latent_tile)
 
         # Prepare an iterator that yields the UNet's LoRA models and their weights.
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
@@ -273,25 +181,54 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
             pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
 
             # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
-            # Assume that all tiles have the same shape.
-            _, _, latent_height, latent_width = latent_tiles[0].shape
             conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
                 unet=unet,
-                latent_height=latent_height,
-                latent_width=latent_width,
+                latent_height=self.tile_height,
+                latent_width=self.tile_width,
                 cfg_scale=self.cfg_scale,
                 steps=self.steps,
                 cfg_rescale_multiplier=self.cfg_rescale_multiplier,
             )
 
-            # Load the ControlNet model.
-            # TODO(ryand): Support multiple ControlNet models.
-            controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
-            assert isinstance(controlnet_model, ControlNetModel)
+            controlnet_data = DenoiseLatentsInvocation.prep_control_data(
+                context=context,
+                control_input=self.control,
+                latents_shape=list(latents.shape),
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+            # Split the controlnet_data into tiles.
+            if controlnet_data is not None:
+                # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
+                controlnet_data_tiles: list[list[ControlNetData]] = []
+                for tile in tiles:
+                    # To split the controlnet_data into tiles, we simply need to crop each image_tensor. All other
+                    # params can be copied unmodified.
+                    tile_controlnet_data = [
+                        ControlNetData(
+                            model=cn.model,
+                            image_tensor=cn.image_tensor[
+                                :,
+                                :,
+                                tile.coords.top * LATENT_SCALE_FACTOR : tile.coords.bottom * LATENT_SCALE_FACTOR,
+                                tile.coords.left * LATENT_SCALE_FACTOR : tile.coords.right * LATENT_SCALE_FACTOR,
+                            ],
+                            weight=cn.weight,
+                            begin_step_percent=cn.begin_step_percent,
+                            end_step_percent=cn.end_step_percent,
+                            control_mode=cn.control_mode,
+                            resize_mode=cn.resize_mode,
+                        )
+                        for cn in controlnet_data
+                    ]
+                    controlnet_data_tiles.append(tile_controlnet_data)
+
+            # TODO(ryand): Logic from here down needs updating --------------------
 
             # Denoise (i.e. "refine") each tile independently.
            for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
                assert latent_tile.shape == noise_tile.shape
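
A note on the global-noise requirement called out in the new TODO comment: because adjacent tiles overlap, the noise added to an overlapping latent region must be identical in every tile that covers it, or the blended seam will not denoise toward a consistent result. The sketch below is illustrative only; the tensor sizes, seed, and tile coordinates are assumed values, not taken from this change. It shows that cropping tiles out of one globally generated noise tensor guarantees identical noise in the overlap.

import torch

# Assumed latent-space geometry for the sketch: a 1x4x112x224 latent image and two
# horizontally adjacent 112x128 tiles that overlap by 32 latent columns.
generator = torch.Generator().manual_seed(0)
global_noise = torch.randn(1, 4, 112, 224, generator=generator)

def crop(t: torch.Tensor, top: int, bottom: int, left: int, right: int) -> torch.Tensor:
    # Crop a (batch, channels, height, width) tensor to a latent-space tile region.
    return t[..., top:bottom, left:right]

tile_a = crop(global_noise, top=0, bottom=112, left=0, right=128)
tile_b = crop(global_noise, top=0, bottom=112, left=96, right=224)

# Latent columns 96-127 appear in both tiles and carry exactly the same noise values,
# which is what lets the overlapping regions of the denoised tiles be blended without
# seams. Noise generated per tile (two independent torch.randn calls) would not satisfy this.
assert torch.equal(tile_a[..., 96:128], tile_b[..., 0:32])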
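
The ControlNet tile-splitting in the last hunk relies on the fact that one latent-space coordinate corresponds to LATENT_SCALE_FACTOR image-space pixels, so a latent tile is mapped onto the full-resolution control image tensor by multiplying its coordinates by that factor. Below is a minimal sketch of that index arithmetic with assumed shapes; the 8x factor, batch size, and the 112x224 latent size are illustrative and not taken from this diff.

import torch

LATENT_SCALE_FACTOR = 8  # assumed VAE downsampling factor (SD-style 8x)

# A control image tensor prepared at full image resolution for a 112x224 latent image;
# the batch dimension of 2 stands in for the CFG-duplicated conditioning batch.
control_image_tensor = torch.zeros(2, 3, 112 * LATENT_SCALE_FACTOR, 224 * LATENT_SCALE_FACTOR)

# A latent-space tile, e.g. the right-hand tile from the previous sketch.
top, bottom, left, right = 0, 112, 96, 224

# Scale every latent coordinate up by LATENT_SCALE_FACTOR to crop the matching
# image-space region out of the control tensor.
control_tile = control_image_tensor[
    :,
    :,
    top * LATENT_SCALE_FACTOR : bottom * LATENT_SCALE_FACTOR,
    left * LATENT_SCALE_FACTOR : right * LATENT_SCALE_FACTOR,
]

# The cropped control tile covers exactly the image-space footprint of the latent tile.
assert control_tile.shape == (2, 3, 112 * LATENT_SCALE_FACTOR, 128 * LATENT_SCALE_FACTOR)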