WIP - TiledStableDiffusionRefine

This commit is contained in:
Ryan Dick 2024-06-07 12:06:35 -04:00
parent bb5648983f
commit 787e1bbb5f

View File

@ -1,5 +1,6 @@
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from PIL import Image
from pydantic import field_validator
from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
@ -15,12 +16,14 @@ from invokeai.app.invocations.fields import (
)
from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
from invokeai.app.invocations.latent import DenoiseLatentsInvocation, get_scheduler
from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
from invokeai.app.invocations.model import UNetField, VAEField
from invokeai.app.invocations.noise import get_noise
from invokeai.app.invocations.primitives import LatentsOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
from invokeai.backend.tiles.utils import Tile
from invokeai.backend.util.devices import TorchDevice
@ -106,6 +109,25 @@ class TiledStableDiffusionRefine(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
"""Crop the latent-space tensor to the area corresponding to the image-space tile.
The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
"""
for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
if coord % LATENT_SCALE_FACTOR != 0:
raise ValueError(
f"The tile coordinates must all be divisible by the latent scale factor"
f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
)
assert latents.dim == 4 # We expect: (batch_size, channels, height, width).
top = image_tile.coords.top // LATENT_SCALE_FACTOR
left = image_tile.coords.left // LATENT_SCALE_FACTOR
bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
right = image_tile.coords.right // LATENT_SCALE_FACTOR
return latents[..., top:bottom, left:right]
@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# TODO(ryand): Expose the seed parameter.
@ -141,7 +163,8 @@ class TiledStableDiffusionRefine(BaseInvocation):
image_tiles.append(image_tile)
# VAE-encode each image tile independently.
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles?
# TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
# about for decoding?
vae_info = context.models.load(self.vae.vae)
latent_tiles: list[torch.Tensor] = []
for image_tile in image_tiles:
@ -157,7 +180,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
# noise.
assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
noise_tiles = get_noise(
global_noise = get_noise(
width=input_image_torch.shape[3],
height=input_image_torch.shape[2],
device=TorchDevice.choose_torch_device(),
@ -166,6 +189,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
use_cpu=True,
)
# Crop the global noise into tiles.
noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
# Load the UNet model.
unet_info = context.models.load(self.unet.unet)
@ -179,9 +205,75 @@ class TiledStableDiffusionRefine(BaseInvocation):
seed=seed,
)
pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
for latent_tile in latent_tiles:
# Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
# Assume that all tiles have the same shape.
_, _, latent_height, latent_width = latent_tiles[0].shape
conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
# Denoise (i.e. "refine") each tile independently.
for latent_tile, noise_tile in zip(latent_tiles, noise_tiles, strict=True):
assert latent_tile.shape == noise_tile.shape
num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
DenoiseLatentsInvocation.init_scheduler(
scheduler,
device=unet.device,
steps=self.steps,
denoising_start=self.denoising_start,
denoising_end=self.denoising_end,
seed=seed,
)
)
refined_latent_tile = pipeline.latents_from_embeddings(
latents=latent_tile,
timesteps=timesteps,
init_timestep=init_timestep,
noise=noise_tile,
seed=seed,
mask=None,
masked_latents=None,
gradient_mask=None,
num_inference_steps=num_inference_steps,
scheduler_step_kwargs=scheduler_step_kwargs,
conditioning_data=conditioning_data,
control_data=None,
ip_adapter_data=None,
t2i_adapter_data=None,
callback=lambda x: None,
)
refined_latent_tiles.append(refined_latent_tile)
# VAE-decode each refined latent tile independently.
refined_image_tiles: list[Image.Image] = []
for refined_latent_tile in refined_latent_tiles:
refined_image_tile = LatentsToImageInvocation.vae_decode(
context=context,
vae_info=vae_info,
seamless_axes=self.vae.seamless_axes,
latents=refined_latent_tile,
use_fp32=self.vae_fp32,
use_tiling=False,
)
refined_image_tiles.append(refined_image_tile)
# Merge the refined image tiles back into a single image.
...
# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
TorchDevice.empty_cache()
name = context.tensors.save(tensor=result_latents)
return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)