WIP TiledMultiDiffusionDenoiseLatents. Updated parameter list and first half of the logic.

This commit is contained in:
Ryan Dick 2024-06-12 11:49:23 -04:00 committed by Kent Keirsey
parent 7e94350351
commit 230e205541

View File

@ -2,37 +2,36 @@ from contextlib import ExitStack
from typing import Iterator, Tuple from typing import Iterator, Tuple
import numpy as np import numpy as np
import numpy.typing as npt
import torch import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image from PIL import Image
from pydantic import field_validator from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES from invokeai.app.invocations.constants import LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
ConditioningField, ConditioningField,
FieldDescriptions, FieldDescriptions,
ImageField,
Input, Input,
InputField, InputField,
LatentsField,
UIType, UIType,
) )
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import ModelIdentifierField, UNetField, VAEField from invokeai.app.invocations.model import UNetField
from invokeai.app.invocations.noise import get_noise from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import ImageOutput from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES, prepare_control_image
from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.lora import LoRAModelRaw
from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.model_patcher import ModelPatcher
from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData, image_resized_to_grid_as_tensor from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData
from invokeai.backend.tiles.tiles import calc_tiles_with_overlap, merge_tiles_with_linear_blending from invokeai.backend.tiles.tiles import (
from invokeai.backend.tiles.utils import Tile calc_tiles_min_overlap,
merge_tiles_with_linear_blending,
)
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.hotfixes import ControlNetModel
@invocation( @invocation(
@ -40,14 +39,19 @@ from invokeai.backend.util.hotfixes import ControlNetModel
title="Tiled Stable Diffusion Refine", title="Tiled Stable Diffusion Refine",
tags=["upscale", "denoise"], tags=["upscale", "denoise"],
category="latents", category="latents",
# TODO(ryand): Reset to 1.0.0 right before release.
version="1.0.0", version="1.0.0",
) )
class TiledStableDiffusionRefineInvocation(BaseInvocation): class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to """Tiled Multi-Diffusion denoising.
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
"""
image: ImageField = InputField(description="Image to be refined.") This node handles automatically tiling the input image. Future iterations of
this node should allow the user to specify custom regions with different parameters for each region to harness the
full power of Multi-Diffusion.
This node has a similar interface to the `DenoiseLatents` node, but it has a reduced feature set (no IP-Adapter,
T2I-Adapter, masking, etc.).
"""
positive_conditioning: ConditioningField = InputField( positive_conditioning: ConditioningField = InputField(
description=FieldDescriptions.positive_cond, input=Input.Connection description=FieldDescriptions.positive_cond, input=Input.Connection
@ -55,16 +59,29 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
negative_conditioning: ConditioningField = InputField( negative_conditioning: ConditioningField = InputField(
description=FieldDescriptions.negative_cond, input=Input.Connection description=FieldDescriptions.negative_cond, input=Input.Connection
) )
noise: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.noise,
input=Input.Connection,
)
latents: LatentsField | None = InputField(
default=None,
description=FieldDescriptions.latents,
input=Input.Connection,
)
# TODO(ryand): Add multiple-of validation. # TODO(ryand): Add multiple-of validation.
tile_height: int = InputField(default=512, gt=0, description="Height of the tiles.") # TODO(ryand): Smaller defaults might make more sense.
tile_width: int = InputField(default=512, gt=0, description="Width of the tiles.") tile_height: int = InputField(default=112, gt=0, description="Height of the tiles in latent space.")
tile_overlap: int = InputField( tile_width: int = InputField(default=112, gt=0, description="Width of the tiles in latent space.")
tile_min_overlap: int = InputField(
default=16, default=16,
gt=0, gt=0,
description="Target overlap between adjacent tiles (the last row/column may overlap more than this).", description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
"this to evenly cover the entire image.",
) )
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps) steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale") cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
# TODO(ryand): The default here should probably be 0.0.
denoising_start: float = InputField( denoising_start: float = InputField(
default=0.65, default=0.65,
ge=0, ge=0,
@ -85,23 +102,10 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
cfg_rescale_multiplier: float = InputField( cfg_rescale_multiplier: float = InputField(
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
) )
vae: VAEField = InputField( control: ControlField | list[ControlField] | None = InputField(
description=FieldDescriptions.vae, default=None,
input=Input.Connection, input=Input.Connection,
) )
vae_fp32: bool = InputField(
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
)
# HACK(ryand): We probably want to allow the user to control all of the parameters in ControlField. But, we akwardly
# don't want to use the image field. Figure out how best to handle this.
# TODO(ryand): Currently, there is no ControlNet preprocessor applied to the tile images. In other words, we pretty
# much assume that it is a tile ControlNet. We need to decide how we want to handle this. E.g. find a way to support
# CN preprocessors, raise a clear warning when a non-tile CN model is selected, hardcode the supported CN models,
# etc.
control_model: ModelIdentifierField = InputField(
description=FieldDescriptions.controlnet_model, ui_type=UIType.ControlNetModel
)
control_weight: float = InputField(default=0.6)
@field_validator("cfg_scale") @field_validator("cfg_scale")
def ge_one(cls, v: list[float] | float) -> list[float] | float: def ge_one(cls, v: list[float] | float) -> list[float] | float:
@ -115,140 +119,44 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
return v return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim() == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
def run_controlnet(
self,
image: Image.Image,
controlnet_model: ControlNetModel,
weight: float,
do_classifier_free_guidance: bool,
width: int,
height: int,
device: torch.device,
dtype: torch.dtype,
control_mode: CONTROLNET_MODE_VALUES = "balanced",
resize_mode: CONTROLNET_RESIZE_VALUES = "just_resize_simple",
) -> ControlNetData:
control_image = prepare_control_image(
image=image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=width,
height=height,
device=device,
dtype=dtype,
control_mode=control_mode,
resize_mode=resize_mode,
)
return ControlNetData(
model=controlnet_model,
image_tensor=control_image,
weight=weight,
begin_step_percent=0.0,
end_step_percent=1.0,
control_mode=control_mode,
# Any resizing needed should currently be happening in prepare_control_image(), but adding resize_mode to
# ControlNetData in case needed in the future.
resize_mode=resize_mode,
)
@torch.no_grad() @torch.no_grad()
def invoke(self, context: InvocationContext) -> ImageOutput: def invoke(self, context: InvocationContext) -> ImageOutput:
# TODO(ryand): Expose the seed parameter. seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents)
seed = 0 _, _, latent_height, latent_width = latents.shape
# Load the input image. # If noise is None, populate it here.
input_image = context.images.get_pil(self.image.image_name) # TODO(ryand): Currently there is logic to generate noise deeper in the stack if it is None. We should just move
# that logic up the stack in all places that it's relied upon (i.e. do it in prepare_noise_and_latents). In this
# Calculate the tile locations to cover the image. # particular case, we want to make sure that the noise is generated globally rather than per-tile so that
# We have selected this tiling strategy to make it easy to achieve tile coords that are multiples of 8. This # overlapping tile regions use the same noise.
# facilitates conversions between image space and latent space. if noise is None:
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.) noise = get_noise(
tiles = calc_tiles_with_overlap( width=latent_width * LATENT_SCALE_FACTOR,
image_height=input_image.height, height=latent_height * LATENT_SCALE_FACTOR,
image_width=input_image.width,
tile_height=self.tile_height,
tile_width=self.tile_width,
overlap=self.tile_overlap,
)
# Convert the input image to a torch.Tensor.
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
input_image_torch = input_image_torch.unsqueeze(0) # Add a batch dimension.
# Validate our assumptions about the shape of input_image_torch.
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
assert input_image_torch.shape[:2] == (1, 3)
# Split the input image into tiles in torch.Tensor format.
image_tiles_torch: list[torch.Tensor] = []
for tile in tiles:
image_tile = input_image_torch[
:,
:,
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
]
image_tiles_torch.append(image_tile)
# Split the input image into tiles in numpy format.
# TODO(ryand): We currently maintain both np.ndarray and torch.Tensor tiles. Ideally, all operations should work
# with torch.Tensor tiles.
input_image_np = np.array(input_image)
image_tiles_np: list[npt.NDArray[np.uint8]] = []
for tile in tiles:
image_tile_np = input_image_np[
tile.coords.top : tile.coords.bottom,
tile.coords.left : tile.coords.right,
:,
]
image_tiles_np.append(image_tile_np)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile_torch in image_tiles_torch:
latent_tiles.append(
ImageToLatentsInvocation.vae_encode(
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile_torch
)
)
# Generate noise with dimensions corresponding to the full image in latent space.
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(), device=TorchDevice.choose_torch_device(),
seed=seed, seed=seed,
downsampling_factor=LATENT_SCALE_FACTOR, downsampling_factor=LATENT_SCALE_FACTOR,
use_cpu=True, use_cpu=True,
) )
# Crop the global noise into tiles. # Calculate the tile locations to cover the latent-space image.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles] # TODO(ryand): Add constraints on the tile params. Is there a multiple-of constraint?
tiles = calc_tiles_min_overlap(
image_height=latent_height,
image_width=latent_width,
tile_height=self.tile_height,
tile_width=self.tile_width,
min_overlap=self.tile_min_overlap,
)
# Split the noise and latents into tiles.
noise_tiles: list[torch.Tensor] = []
latent_tiles: list[torch.Tensor] = []
for tile in tiles:
noise_tile = noise[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
latent_tile = latents[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right]
noise_tiles.append(noise_tile)
latent_tiles.append(latent_tile)
# Prepare an iterator that yields the UNet's LoRA models and their weights. # Prepare an iterator that yields the UNet's LoRA models and their weights.
def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
@ -273,25 +181,54 @@ class TiledStableDiffusionRefineInvocation(BaseInvocation):
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler) pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles. # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data( conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context, context=context,
positive_conditioning_field=self.positive_conditioning, positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning, negative_conditioning_field=self.negative_conditioning,
unet=unet, unet=unet,
latent_height=latent_height, latent_height=self.tile_height,
latent_width=latent_width, latent_width=self.tile_width,
cfg_scale=self.cfg_scale, cfg_scale=self.cfg_scale,
steps=self.steps, steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier, cfg_rescale_multiplier=self.cfg_rescale_multiplier,
) )
# Load the ControlNet model. controlnet_data = DenoiseLatentsInvocation.prep_control_data(
# TODO(ryand): Support multiple ControlNet models. context=context,
controlnet_model = exit_stack.enter_context(context.models.load(self.control_model)) control_input=self.control,
assert isinstance(controlnet_model, ControlNetModel) latents_shape=list(latents.shape),
# do_classifier_free_guidance=(self.cfg_scale >= 1.0))
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
# Split the controlnet_data into tiles.
if controlnet_data is not None:
# controlnet_data_tiles[t][c] is the c'th control data for the t'th tile.
controlnet_data_tiles: list[list[ControlNetData]] = []
for tile in tiles:
# To split the controlnet_data into tiles, we simply need to crop each image_tensor. All other
# params can be copied unmodified.
tile_controlnet_data = [
ControlNetData(
model=cn.model,
image_tensor=cn.image_tensor[
:,
:,
tile.coords.top * LATENT_SCALE_FACTOR : tile.coords.bottom * LATENT_SCALE_FACTOR,
tile.coords.left * LATENT_SCALE_FACTOR : tile.coords.right * LATENT_SCALE_FACTOR,
],
weight=cn.weight,
begin_step_percent=cn.begin_step_percent,
end_step_percent=cn.end_step_percent,
control_mode=cn.control_mode,
resize_mode=cn.resize_mode,
)
for cn in controlnet_data
]
controlnet_data_tiles.append(tile_controlnet_data)
# TODO(ryand): Logic from here down needs updating --------------------
# Denoise (i.e. "refine") each tile independently. # Denoise (i.e. "refine") each tile independently.
for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True): for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape assert latent_tile.shape == noise_tile.shape