Refactor TiledStableDiffusionRefineInvocation to more closely mirror TiledMultiDiffusionDenoiseLatents. The biggest improvement is in the handling of ControlNets: global ControlNet info can now be passed in, and it is tiled within the node.

Ryan Dick 2024-06-26 20:39:29 -04:00
parent b74bc77347
commit 8379feeb8a

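A minimal sketch of the tiling arithmetic this diff relies on: tiles are now computed in latent space and scaled up by `LATENT_SCALE_FACTOR` (8 for SD), so every image-space tile coordinate is guaranteed to be divisible by 8 and image tiles stay exactly aligned with latent tiles. This is a standalone illustration, not the commit's code; the stand-in `TBLR`/`Tile` dataclasses play the role of `invokeai.backend.tiles.utils`, and `scale_tile` mirrors the `_scale_tile()` helper added below.

```python
# Sketch (not the commit's code): stand-ins for invokeai.backend.tiles.utils.TBLR/Tile.
from dataclasses import dataclass

LATENT_SCALE_FACTOR = 8  # SD VAE downscale factor


@dataclass
class TBLR:
    top: int
    bottom: int
    left: int
    right: int


@dataclass
class Tile:
    coords: TBLR  # tile bounds
    overlap: TBLR  # overlap with neighboring tiles on each edge


def scale_tblr(t: TBLR, s: int) -> TBLR:
    return TBLR(top=t.top * s, bottom=t.bottom * s, left=t.left * s, right=t.right * s)


def scale_tile(tile: Tile, scale: int) -> Tile:
    # Scaling a latent-space tile by 8 always lands on image coordinates that are
    # multiples of 8, so no divisibility validation is needed downstream.
    return Tile(coords=scale_tblr(tile.coords, scale), overlap=scale_tblr(tile.overlap, scale))


# A 128x128 latent tile becomes a 1024x1024 image tile:
latent_tile = Tile(coords=TBLR(0, 128, 0, 128), overlap=TBLR(0, 16, 0, 16))
print(scale_tile(latent_tile, LATENT_SCALE_FACTOR).coords)
# TBLR(top=0, bottom=1024, left=0, right=1024)
```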

@@ -2,14 +2,14 @@ from contextlib import ExitStack
 from typing import Iterator, Tuple
 
 import numpy as np
-import numpy.typing as npt
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
 from PIL import Image
 from pydantic import field_validator
 
-from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
+from invokeai.app.invocations.baseinvocation import BaseInvocation, Classification, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
 from invokeai.app.invocations.fields import (
     ConditioningField,
@@ -17,22 +17,24 @@ from invokeai.app.invocations.fields import (
     ImageField,
     Input,
     InputField,
+    LatentsField,
     UIType,
 )
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
 from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
-from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField
-from invokeai.app.invocations.noise import get_noise
+from invokeai.app.invocations.model import UNetField, VAEField
 from invokeai.app.invocations.primitives import ImageOutput
+from invokeai.app.invocations.tiled_multi_diffusion_denoise_latents import crop_controlnet_data
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
 from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_patcher import ModelPatcher
 from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor
-from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending
-from invokeai.backend.tiles.utils import Tile
+from invokeai.backend.tiles.tiles import (
+    calc_tiles_min_overlap,
+    merge_tiles_with_linear_blending,
+)
+from invokeai.backend.tiles.utils import TBLR, Tile
 from invokeai.backend.util.devices import TorchDevice
-from invokeai.backend.util.hotfixes import ControlNetModel
 
 
 @invocation(
@@ -40,6 +42,7 @@ from invokeai.backend.util.hotfixes import ControlNetModel
     title="Tiled Stable Diffusion Refine",
     tags=["upscale", "denoise"],
     category="latents",
+    classification=Classification.Beta,
     version="1.0.0",
 )
 class TiledStableDiffusionRefineInvocation(BaseInvocation):
@@ -55,13 +58,21 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
     negative_conditioning: ConditioningField = InputField(
         description=FieldDescriptions.negative_cond, input=Input.Connection
     )
-    # TODO(ryand): Add multiple-of validation.
-    tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.")
-    tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.")
+    noise: LatentsField = InputField(
+        description=FieldDescriptions.noise,
+        input=Input.Connection,
+    )
+    tile_height: int = InputField(
+        default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Height of the tiles in image space."
+    )
+    tile_width: int = InputField(
+        default=1024, gt=0, multiple_of=LATENT_SCALE_FACTOR, description="Width of the tiles in image space."
+    )
     tile_overlap: int = InputField(
-        default=16,
+        default=32,
+        multiple_of=LATENT_SCALE_FACTOR,
         gt=0,
-        description="Target overlap between adjacent tiles (the last row/column may overlap more than this).",
+        description="Target overlap between adjacent tiles in image space.",
     )
     steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
     cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
@@ -92,16 +103,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
     vae_fp32: bool = InputField(
         default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
     )
-    # HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
-    # don't want to use the image field. Figure out how best to handle this.
-    # TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
-    # much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
-    # CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
-    # etc.
-    control_model: ModelIdentifierField = InputField(
-        description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
-    )
-    control_weight: float = InputField(default=0.6)
+    control: ControlField | list[ControlField] | None = InputField(
+        default=None,
+        input=Input.Connection,
+    )
 
     @field_validator("cfg_scale")
     def ge_one(cls, v: list[float] | float) -> list[float] | float:
@@ -115,90 +120,72 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
             raise ValueError("cfg_scale must be greater than 1")
         return v
 
-    @staticmethod
-    def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
-        """Crop the latent-space tensor to the area corresponding to the image-space tile.
-
-        The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
-        """
-        for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
-            if coord % LATENT_SCALE_FACTOR != 0:
-                raise ValueError(
-                    f"The tile coordinates must all be divisible by the latent scale factor"
-                    f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
-                )
-        assert latents.dim() == 4  # We expect: (batch_size, channels, height, width).
-
-        top = image_tile.coords.top // LATENT_SCALE_FACTOR
-        left = image_tile.coords.left // LATENT_SCALE_FACTOR
-        bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
-        right = image_tile.coords.right // LATENT_SCALE_FACTOR
-        return latents[..., top:bottom, left:right]
-
-    def run_controlnet(
-        self,
-        image: Image.Image,
-        controlnet_model: ControlNetModel,
-        weight: float,
-        do_classifier_free_guidance: bool,
-        width: int,
-        height: int,
-        device: torch.device,
-        dtype: torch.dtype,
-        control_mode: CONTROLNET_MODE_VALUES = "balanced",
-        resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
-    ) -> ControlNetData:
-        control_image = prepare_control_image(
-            image=image,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            width=width,
-            height=height,
-            device=device,
-            dtype=dtype,
-            control_mode=control_mode,
-            resize_mode=resize_mode,
-        )
-        return ControlNetData(
-            model=controlnet_model,
-            image_tensor=control_image,
-            weight=weight,
-            begin_step_percent=0.0,
-            end_step_percent=1.0,
-            control_mode=control_mode,
-            # Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
-            # ControlNetData in case needed in the future.
-            resize_mode=resize_mode,
-        )
+    def _scale_tile(self, tile: Tile, scale: int) -> Tile:
+        """Scale the tile by the given factor."""
+        return Tile(
+            coords=TBLR(
+                top=tile.coords.top * scale,
+                bottom=tile.coords.bottom * scale,
+                left=tile.coords.left * scale,
+                right=tile.coords.right * scale,
+            ),
+            overlap=TBLR(
+                top=tile.overlap.top * scale,
+                bottom=tile.overlap.bottom * scale,
+                left=tile.overlap.left * scale,
+                right=tile.overlap.right * scale,
+            ),
+        )
 
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> ImageOutput:
-        # TODO(ryand): Expose the seed parameter.
-        seed = 0
+        # Convert tile image-space dimensions to latent-space dimensions.
+        latent_tile_height = self.tile_height // LATENT_SCALE_FACTOR
+        latent_tile_width = self.tile_width // LATENT_SCALE_FACTOR
+        latent_tile_overlap = self.tile_overlap // LATENT_SCALE_FACTOR
 
         # Load the input image.
         input_image = context.images.get_pil(self.image.image_name)
 
-        # Calculate the tile locations to cover the image.
-        # We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This
-        # facilitates conversions between image space and latent space.
-        # TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
-        tiles = calc_tiles_with_overlap(
-            image_height=input_image.height,
-            image_width=input_image.width,
-            tile_height=self.tile_height,
-            tile_width=self.tile_width,
-            overlap=self.tile_overlap,
-        )
-
         # Convert the input image to a torch.Tensor.
         input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
         input_image_torch = input_image_torch.unsqueeze(0)  # Add a batch dimension.
 
         # Validate our assumptions about the shape of input_image_torch.
-        assert input_image_torch.dim() == 4  # We expect: (batch_size, channels, height, width).
-        assert input_image_torch.shape[:2] == (1, 3)
+        batch_size, channels, image_height, image_width = input_image_torch.shape
+        assert batch_size == 1
+        assert channels == 3
+
+        # Load the noise tensor.
+        noise = context.tensors.load(self.noise.latents_name)
+        if list(noise.shape) != [
+            batch_size,
+            4,
+            image_height // LATENT_SCALE_FACTOR,
+            image_width // LATENT_SCALE_FACTOR,
+        ]:
+            raise ValueError(
+                f"Incompatible noise and image dimensions. Image shape: {input_image_torch.shape}. "
+                f"Noise shape: {noise.shape}. Expected noise shape: [1, 4, "
+                f"{image_height // LATENT_SCALE_FACTOR}, {image_width // LATENT_SCALE_FACTOR}]. "
+            )
+        latent_height, latent_width = noise.shape[2:]
+
+        # Extract the seed from the noise field.
+        assert self.noise.seed is not None
+        seed = self.noise.seed or 0
+
+        # Calculate the tile locations in both latent space and image space.
+        latent_space_tiles = calc_tiles_min_overlap(
+            image_height=latent_height,
+            image_width=latent_width,
+            tile_height=latent_tile_height,
+            tile_width=latent_tile_width,
+            min_overlap=latent_tile_overlap,
+        )
+        image_space_tiles = [self._scale_tile(tile, LATENT_SCALE_FACTOR) for tile in latent_space_tiles]
 
         # Split the input image into tiles in torch.Tensor format.
         image_tiles_torch: list[torch.Tensor] = []
-        for tile in tiles:
+        for tile in image_space_tiles:
             image_tile = input_image_torch[
                 :,
                 :,
@@ -207,22 +194,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
             ]
             image_tiles_torch.append(image_tile)
 
-        # Split the input image into tiles in numpy format.
-        # TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
-        # with torch.Tensor tiles.
-        input_image_np = np.array(input_image)
-        image_tiles_np: list[npt.NDArray[np.uint8]] = []
-        for tile in tiles:
-            image_tile_np = input_image_np[
-                tile.coords.top : tile.coords.bottom,
-                tile.coords.left : tile.coords.right,
-                :,
-            ]
-            image_tiles_np.append(image_tile_np)
-
         # VAE-encode each image tile independently.
-        # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
-        # about for decoding?
         vae_info = context.models.load(self.vae.vae)
         latent_tiles: list[torch.Tensor] = []
         for image_tile_torch in image_tiles_torch:
@@ -232,23 +204,16 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                 )
             )
 
-        # Generate noise with dimensions corresponding to the full image in latent space.
-        # It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
-        # generating for each tile independently. This ensures that overlapping regions between tiles use the same
-        # noise.
-        assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
-        assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
-        global_noise = get_noise(
-            width=input_image_torch.shape[3],
-            height=input_image_torch.shape[2],
-            device=TorchDevice.choose_torch_device(),
-            seed=seed,
-            downsampling_factor=LATENT_SCALE_FACTOR,
-            use_cpu=True,
-        )
-
         # Crop the global noise into tiles.
-        noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
+        noise_tiles: list[torch.Tensor] = []
+        for tile in latent_space_tiles:
+            noise_tile = noise[
+                :,
+                :,
+                tile.coords.top : tile.coords.bottom,
+                tile.coords.left : tile.coords.right,
+            ]
+            noise_tiles.append(noise_tile)
 
         # Prepare an iterator that yields the UNet's LoRA models and their weights.
         def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
@@ -273,53 +238,42 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
             pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
 
             # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
-            # Assume that all tiles have the same shape.
-            _, _, latent_height, latent_width = latent_tiles[0].shape
             conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
                 context=context,
                 positive_conditioning_field=self.positive_conditioning,
                 negative_conditioning_field=self.negative_conditioning,
                 unet=unet,
-                latent_height=latent_height,
-                latent_width=latent_width,
+                latent_height=latent_tile_height,
+                latent_width=latent_tile_width,
                 cfg_scale=self.cfg_scale,
                 steps=self.steps,
                 cfg_rescale_multiplier=self.cfg_rescale_multiplier,
             )
 
-            # Load the ControlNet model.
-            # TODO(ryand): Support multiple ControlNet models.
-            controlnet_model = exit_stack.enter_context(context.models.load(self.control_model))
-            assert isinstance(controlnet_model, ControlNetModel)
+            controlnet_data = DenoiseLatentsInvocation.prep_control_data(
+                context=context,
+                control_input=self.control,
+                # NOTE: We use the shape of the global noise tensor here, because this is a global ControlNet. We tile
+                # it later.
+                latents_shape=list(noise.shape),
+                # do_classifier_free_guidance=(self.cfg_scale >= 1.0))
+                do_classifier_free_guidance=True,
+                exit_stack=exit_stack,
+            )
+
+            # Split the controlnet_data into tiles.
+            # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
+            controlnet_data_tiles: list[list[ControlNetData]] = []
+            for tile in latent_space_tiles:
+                tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []]
+                controlnet_data_tiles.append(tile_controlnet_data)
 
             # Denoise (i.e. "refine") each tile independently.
-            for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
+            for latent_tile, noise_tile, controlnet_data_tile in zip(
+                latent_tiles, noise_tiles, controlnet_data_tiles, strict=True
+            ):
                 assert latent_tile.shape == noise_tile.shape
 
-                # Prepare a PIL Image for ControlNet processing.
-                # TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of
-                # the tiles. Ideally, the ControlNet code should be able to work with Tensors.
-                image_tile_pil = Image.fromarray(image_tile_np)
-
-                # Run the ControlNet on the image tile.
-                height, width, _ = image_tile_np.shape
-                # The height and width must be evenly divisible by LATENT_SCALE_FACTOR. This is enforced earlier, but we
-                # validate this assumption here.
-                assert height % LATENT_SCALE_FACTOR == 0
-                assert width % LATENT_SCALE_FACTOR == 0
-                controlnet_data = self.run_controlnet(
-                    image=image_tile_pil,
-                    controlnet_model=controlnet_model,
-                    weight=self.control_weight,
-                    do_classifier_free_guidance=True,
-                    width=width,
-                    height=height,
-                    device=controlnet_model.device,
-                    dtype=controlnet_model.dtype,
-                    control_mode="balanced",
-                    resize_mode="just_resize_simple",
-                )
-
                 timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler(
                     scheduler,
                     device=unet.device,
@@ -342,7 +296,7 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
                     masked_latents=None,
                     scheduler_step_kwargs=scheduler_step_kwargs,
                     conditioning_data=conditioning_data,
-                    control_data=[controlnet_data],
+                    control_data=controlnet_data_tile,
                     ip_adapter_data=None,
                     t2i_adapter_data=None,
                     callback=lambda x: None,
@@ -368,9 +322,11 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
         # Merge the refined image tiles back into a single image.
         refined_image_tiles_np = [np.array(t) for t in refined_image_tiles]
         merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8)
-        # TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter?
         merge_tiles_with_linear_blending(
-            dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap
+            dst_image=merged_image_np,
+            tiles=image_space_tiles,
+            tile_images=refined_image_tiles_np,
+            blend_amount=self.tile_overlap,
         )
 
         # Save the refined image and return its reference.
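The headline change, passing global ControlNet conditioning in once and tiling it inside the node, boils down to slicing each full-image control tensor with the same tile bounds used for the latents. Below is a rough, self-contained sketch of what the borrowed `crop_controlnet_data()` helper has to do; `ControlNetDataStub` and `crop_control_to_tile` are illustrative stand-ins, and the real `ControlNetData` carries additional fields (e.g. `begin_step_percent`/`end_step_percent`) that pass through unchanged.

```python
# Sketch (not the commit's code): approximates the per-tile cropping performed by
# crop_controlnet_data() from tiled_multi_diffusion_denoise_latents.
from dataclasses import dataclass, replace

import torch

LATENT_SCALE_FACTOR = 8


@dataclass
class ControlNetDataStub:
    image_tensor: torch.Tensor  # (batch, channels, image_height, image_width)
    weight: float = 0.6


def crop_control_to_tile(
    cn: ControlNetDataStub, top: int, bottom: int, left: int, right: int
) -> ControlNetDataStub:
    """Crop full-image control conditioning down to one tile.

    The bounds are latent-space tile coordinates, so they are scaled by
    LATENT_SCALE_FACTOR before slicing the image-space control tensor.
    """
    f = LATENT_SCALE_FACTOR
    return replace(cn, image_tensor=cn.image_tensor[:, :, top * f : bottom * f, left * f : right * f])


# A 1024x1024 control image cropped to the latent tile (rows 0..64, cols 64..128):
cn = ControlNetDataStub(image_tensor=torch.zeros(2, 3, 1024, 1024))
print(crop_control_to_tile(cn, top=0, bottom=64, left=64, right=128).image_tensor.shape)
# -> torch.Size([2, 3, 512, 512])
```

Because the control tensor is prepared once at full-image resolution (note the `latents_shape=list(noise.shape)` argument in the diff), every cropped tile stays pixel-aligned with its latent tile, so overlapping regions between neighboring tiles see identical conditioning.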