2024-06-06 21:39:34 +00:00
|
|
|
import torch
|
|
|
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
2024-06-07 16:06:35 +00:00
|
|
|
from PIL import Image
|
2024-06-06 21:39:34 +00:00
|
|
|
from pydantic import field_validator
|
|
|
|
|
|
|
|
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
|
|
|
|
from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
|
2024-06-07 16:06:35 +00:00
|
|
|
from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
|
2024-06-06 21:39:34 +00:00
|
|
|
from invokeai.app.invocations.fields import (
|
|
|
|
ConditioningField,
|
|
|
|
FieldDescriptions,
|
|
|
|
ImageField,
|
|
|
|
Input,
|
|
|
|
InputField,
|
|
|
|
LatentsField,
|
|
|
|
UIType,
|
|
|
|
)
|
|
|
|
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
|
2024-06-07 16:06:35 +00:00
|
|
|
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
|
2024-06-06 21:39:34 +00:00
|
|
|
from invokeai.app.invocations.model import UNetField, VAEField
|
|
|
|
from invokeai.app.invocations.noise import get_noise
|
|
|
|
from invokeai.app.invocations.primitives import LatentsOutput
|
|
|
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
|
|
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
|
|
|
|
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
|
2024-06-07 16:06:35 +00:00
|
|
|
from invokeai.backend.tiles.utils import Tile
|
2024-06-06 21:39:34 +00:00
|
|
|
from invokeai.backend.util.devices import TorchDevice
|
|
|
|
|
|
|
|
|
|
|
|
@invocation(
|
|
|
|
"tiled_stable_diffusion_refine",
|
|
|
|
title="Tiled Stable Diffusion Refine",
|
|
|
|
tags=["upscale", "denoise"],
|
|
|
|
category="latents",
|
|
|
|
version="1.0.0",
|
|
|
|
)
|
|
|
|
class TiledStableDiffusionRefine(BaseInvocation):
|
|
|
|
"""A tiled Stable Diffusion pipeline for refining high resolution images. This invocation is intended to be used to
|
|
|
|
refine an image after upscaling i.e. it is the second step in a typical "tiled upscaling" workflow.
|
|
|
|
"""
|
|
|
|
|
|
|
|
# Implementation order:
|
|
|
|
# - Basic tiled denoising. Support text prompts, but no other features.
|
|
|
|
# - Support LoRA + TI
|
|
|
|
# - Support ControlNet
|
|
|
|
# - IP-Adapter? (It has to run on each tile independently. Could be complicated to support batching.)
|
|
|
|
|
|
|
|
image: ImageField = InputField(description="Image to be refined.")
|
|
|
|
|
2024-06-07 14:01:50 +00:00
|
|
|
positive_conditioning: ConditioningField = InputField(
|
2024-06-06 21:39:34 +00:00
|
|
|
description=FieldDescriptions.positive_cond, input=Input.Connection
|
|
|
|
)
|
2024-06-07 14:01:50 +00:00
|
|
|
negative_conditioning: ConditioningField = InputField(
|
2024-06-06 21:39:34 +00:00
|
|
|
description=FieldDescriptions.negative_cond, input=Input.Connection
|
|
|
|
)
|
|
|
|
noise: LatentsField | None = InputField(
|
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.noise,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
steps: int = InputField(default=10, gt=0, description=FieldDescriptions.steps)
|
|
|
|
cfg_scale: float | list[float] = InputField(default=7.5, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
|
|
|
denoising_start: float = InputField(
|
|
|
|
default=0.0,
|
|
|
|
ge=0,
|
|
|
|
le=1,
|
|
|
|
description=FieldDescriptions.denoising_start,
|
|
|
|
)
|
|
|
|
denoising_end: float = InputField(default=1.0, ge=0, le=1, description=FieldDescriptions.denoising_end)
|
|
|
|
scheduler: SCHEDULER_NAME_VALUES = InputField(
|
|
|
|
default="euler",
|
|
|
|
description=FieldDescriptions.scheduler,
|
|
|
|
ui_type=UIType.Scheduler,
|
|
|
|
)
|
|
|
|
unet: UNetField = InputField(
|
|
|
|
description=FieldDescriptions.unet,
|
|
|
|
input=Input.Connection,
|
|
|
|
title="UNet",
|
|
|
|
)
|
|
|
|
# control: Optional[Union[ControlField, list[ControlField]]] = InputField(
|
|
|
|
# default=None,
|
|
|
|
# input=Input.Connection,
|
|
|
|
# )
|
|
|
|
cfg_rescale_multiplier: float = InputField(
|
|
|
|
title="CFG Rescale Multiplier", default=0, ge=0, lt=1, description=FieldDescriptions.cfg_rescale_multiplier
|
|
|
|
)
|
|
|
|
latents: LatentsField | None = InputField(
|
|
|
|
default=None,
|
|
|
|
description=FieldDescriptions.latents,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
vae: VAEField = InputField(
|
|
|
|
description=FieldDescriptions.vae,
|
|
|
|
input=Input.Connection,
|
|
|
|
)
|
|
|
|
vae_fp32: bool = InputField(
|
|
|
|
default=DEFAULT_PRECISION == torch.float32, description="Whether to use float32 precision when running the VAE."
|
|
|
|
)
|
|
|
|
|
|
|
|
@field_validator("cfg_scale")
|
|
|
|
def ge_one(cls, v: list[float] | float) -> list[float] | float:
|
|
|
|
"""Validate that all cfg_scale values are >= 1"""
|
|
|
|
if isinstance(v, list):
|
|
|
|
for i in v:
|
|
|
|
if i < 1:
|
|
|
|
raise ValueError("cfg_scale must be greater than 1")
|
|
|
|
else:
|
|
|
|
if v < 1:
|
|
|
|
raise ValueError("cfg_scale must be greater than 1")
|
|
|
|
return v
|
|
|
|
|
2024-06-07 16:06:35 +00:00
|
|
|
@staticmethod
|
|
|
|
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
|
|
|
|
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
|
|
|
|
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
|
|
|
|
"""
|
|
|
|
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
|
|
|
|
if coord % LATENT_SCALE_FACTOR != 0:
|
|
|
|
raise ValueError(
|
|
|
|
f"The tile coordinates must all be divisible by the latent scale factor"
|
|
|
|
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
|
|
|
|
)
|
|
|
|
assert latents.dim == 4 # We expect: (batch_size, channels, height, width).
|
|
|
|
|
|
|
|
top = image_tile.coords.top // LATENT_SCALE_FACTOR
|
|
|
|
left = image_tile.coords.left // LATENT_SCALE_FACTOR
|
|
|
|
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
|
|
|
|
right = image_tile.coords.right // LATENT_SCALE_FACTOR
|
|
|
|
return latents[..., top:bottom, left:right]
|
|
|
|
|
2024-06-06 21:39:34 +00:00
|
|
|
@torch.no_grad()
|
|
|
|
def invoke(self, context: InvocationContext) -> LatentsOutput:
|
|
|
|
# TODO(ryand): Expose the seed parameter.
|
|
|
|
seed = 0
|
|
|
|
|
|
|
|
# Load the input image.
|
|
|
|
input_image = context.images.get_pil(self.image.image_name)
|
|
|
|
input_image_torch = image_resized_to_grid_as_tensor(input_image.convert("RGB"), multiple_of=LATENT_SCALE_FACTOR)
|
|
|
|
|
|
|
|
# Calculate the tile locations to cover the image.
|
|
|
|
# TODO(ryand): Expose these tiling parameters. (Keep in mind the multiple-of constraints on these params.)
|
|
|
|
tiles = calc_tiles_min_overlap(
|
|
|
|
image_height=input_image.height,
|
|
|
|
image_width=input_image.width,
|
|
|
|
tile_height=512,
|
|
|
|
tile_width=512,
|
|
|
|
min_overlap=128,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Validate our assumptions about the shape of input_image_torch.
|
|
|
|
assert input_image_torch.dim() == 4 # We expect: (batch_size, channels, height, width).
|
|
|
|
assert input_image_torch.shape[:2] == (1, 3)
|
|
|
|
|
|
|
|
# Split the input image into tiles.
|
|
|
|
image_tiles: list[torch.Tensor] = []
|
|
|
|
for tile in tiles:
|
|
|
|
image_tile = input_image_torch[
|
|
|
|
:,
|
|
|
|
:,
|
|
|
|
tile.coords.top : tile.coords.bottom,
|
|
|
|
tile.coords.left : tile.coords.right,
|
|
|
|
]
|
|
|
|
image_tiles.append(image_tile)
|
|
|
|
|
|
|
|
# VAE-encode each image tile independently.
|
2024-06-07 16:06:35 +00:00
|
|
|
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
|
|
|
|
# about for decoding?
|
2024-06-06 21:39:34 +00:00
|
|
|
vae_info = context.models.load(self.vae.vae)
|
|
|
|
latent_tiles: list[torch.Tensor] = []
|
|
|
|
for image_tile in image_tiles:
|
|
|
|
latent_tiles.append(
|
|
|
|
ImageToLatentsInvocation.vae_encode(
|
|
|
|
vae_info=vae_info, upcast=self.vae_fp32, tiled=False, image_tensor=image_tile
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
# Generate noise with dimensions corresponding to the full image in latent space.
|
|
|
|
# It is important that the noise tensor is generated at the full image dimension and then tiled, rather than
|
|
|
|
# generating for each tile independently. This ensures that overlapping regions between tiles use the same
|
|
|
|
# noise.
|
|
|
|
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
|
|
|
|
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
|
2024-06-07 16:06:35 +00:00
|
|
|
global_noise = get_noise(
|
2024-06-06 21:39:34 +00:00
|
|
|
width=input_image_torch.shape[3],
|
|
|
|
height=input_image_torch.shape[2],
|
|
|
|
device=TorchDevice.choose_torch_device(),
|
|
|
|
seed=seed,
|
|
|
|
downsampling_factor=LATENT_SCALE_FACTOR,
|
|
|
|
use_cpu=True,
|
|
|
|
)
|
|
|
|
|
2024-06-07 16:06:35 +00:00
|
|
|
# Crop the global noise into tiles.
|
|
|
|
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
|
|
|
|
|
2024-06-06 21:39:34 +00:00
|
|
|
# Load the UNet model.
|
|
|
|
unet_info = context.models.load(self.unet.unet)
|
|
|
|
|
|
|
|
refined_latent_tiles: list[torch.Tensor] = []
|
|
|
|
with unet_info as unet:
|
|
|
|
assert isinstance(unet, UNet2DConditionModel)
|
|
|
|
scheduler = get_scheduler(
|
|
|
|
context=context,
|
|
|
|
scheduler_info=self.unet.scheduler,
|
|
|
|
scheduler_name=self.scheduler,
|
|
|
|
seed=seed,
|
|
|
|
)
|
2024-06-07 16:06:35 +00:00
|
|
|
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
|
|
|
|
|
|
|
|
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
|
|
|
|
# Assume that all tiles have the same shape.
|
|
|
|
_, _, latent_height, latent_width = latent_tiles[0].shape
|
|
|
|
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
|
|
|
|
context=context,
|
|
|
|
positive_conditioning_field=self.positive_conditioning,
|
|
|
|
negative_conditioning_field=self.negative_conditioning,
|
|
|
|
unet=unet,
|
|
|
|
latent_height=latent_height,
|
|
|
|
latent_width=latent_width,
|
|
|
|
cfg_scale=self.cfg_scale,
|
|
|
|
steps=self.steps,
|
|
|
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Denoise (i.e. "refine") each tile independently.
|
|
|
|
for latent_tile, noise_tile in zip(latent_tiles, noise_tiles, strict=True):
|
|
|
|
assert latent_tile.shape == noise_tile.shape
|
|
|
|
|
|
|
|
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
|
|
|
|
DenoiseLatentsInvocation.init_scheduler(
|
|
|
|
scheduler,
|
|
|
|
device=unet.device,
|
|
|
|
steps=self.steps,
|
|
|
|
denoising_start=self.denoising_start,
|
|
|
|
denoising_end=self.denoising_end,
|
|
|
|
seed=seed,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
refined_latent_tile = pipeline.latents_from_embeddings(
|
|
|
|
latents=latent_tile,
|
|
|
|
timesteps=timesteps,
|
|
|
|
init_timestep=init_timestep,
|
|
|
|
noise=noise_tile,
|
|
|
|
seed=seed,
|
|
|
|
mask=None,
|
|
|
|
masked_latents=None,
|
|
|
|
gradient_mask=None,
|
|
|
|
num_inference_steps=num_inference_steps,
|
|
|
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
|
|
|
conditioning_data=conditioning_data,
|
|
|
|
control_data=None,
|
|
|
|
ip_adapter_data=None,
|
|
|
|
t2i_adapter_data=None,
|
|
|
|
callback=lambda x: None,
|
|
|
|
)
|
|
|
|
refined_latent_tiles.append(refined_latent_tile)
|
|
|
|
|
|
|
|
# VAE-decode each refined latent tile independently.
|
|
|
|
refined_image_tiles: list[Image.Image] = []
|
|
|
|
for refined_latent_tile in refined_latent_tiles:
|
|
|
|
refined_image_tile = LatentsToImageInvocation.vae_decode(
|
|
|
|
context=context,
|
|
|
|
vae_info=vae_info,
|
|
|
|
seamless_axes=self.vae.seamless_axes,
|
|
|
|
latents=refined_latent_tile,
|
|
|
|
use_fp32=self.vae_fp32,
|
|
|
|
use_tiling=False,
|
|
|
|
)
|
|
|
|
refined_image_tiles.append(refined_image_tile)
|
2024-06-06 21:39:34 +00:00
|
|
|
|
2024-06-07 16:06:35 +00:00
|
|
|
# Merge the refined image tiles back into a single image.
|
|
|
|
...
|
2024-06-06 21:39:34 +00:00
|
|
|
|
2024-06-07 16:06:35 +00:00
|
|
|
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
|
|
|
|
result_latents = result_latents.to("cpu")
|
|
|
|
TorchDevice.empty_cache()
|
2024-06-06 21:39:34 +00:00
|
|
|
|
|
|
|
name = context.tensors.save(tensor=result_latents)
|
|
|
|
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
|