From d5ab8cab5c8965616a640730244c00a0a043cdac Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 7 Jun 2024 12:06:35 -0400
Subject: [PATCH] WIP - TiledStableDiffusionRefine

---
 .../tiled_stable_diffusion_refine.py          | 102 +++++++++++++++++-
 1 file changed, 97 insertions(+), 5 deletions(-)

diff --git a/invokeai/app/invocations/tiled_stable_diffusion_refine.py b/invokeai/app/invocations/tiled_stable_diffusion_refine.py
index 0fac95c16c..2b983c14a7 100644
--- a/invokeai/app/invocations/tiled_stable_diffusion_refine.py
+++ b/invokeai/app/invocations/tiled_stable_diffusion_refine.py
@@ -1,9 +1,11 @@
 import torch
 from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from PIL import Image
 from pydantic import field_validator
 
 from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation
 from invokeai.app.invocations.constants import DEFAULT_PRECISION, LATENT_SCALE_FACTOR, SCHEDULER_NAME_VALUES
+from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
 from invokeai.app.invocations.fields import (
     ConditioningField,
     FieldDescriptions,
@@ -14,13 +16,14 @@ from invokeai.app.invocations.fields import (
     UIType,
 )
 from invokeai.app.invocations.image_to_latents import ImageToLatentsInvocation
-from invokeai.app.invocations.denoise_latents import DenoiseLatentsInvocation, get_scheduler
+from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation
 from invokeai.app.invocations.model import UNetField, VAEField
 from invokeai.app.invocations.noise import get_noise
 from invokeai.app.invocations.primitives import LatentsOutput
 from invokeai.app.services.shared.invocation_context import InvocationContext
 from invokeai.backend.stable_diffusion.diffusers_pipeline import image_resized_to_grid_as_tensor
 from invokeai.backend.tiles.tiles import calc_tiles_min_overlap
+from invokeai.backend.tiles.utils import Tile
 from invokeai.backend.util.devices import TorchDevice
 
 
@@ -106,6 +109,25 @@ class TiledStableDiffusionRefine(BaseInvocation):
             raise ValueError("cfg_scale must be greater than 1")
         return v
 
+    @staticmethod
+    def crop_latents_to_tile(latents: torch.Tensor, image_tile: Tile) -> torch.Tensor:
+        """Crop the latent-space tensor to the area corresponding to the image-space tile.
+        The tile coordinates must be divisible by the LATENT_SCALE_FACTOR.
+        """
+        for coord in [image_tile.coords.top, image_tile.coords.left, image_tile.coords.right, image_tile.coords.bottom]:
+            if coord % LATENT_SCALE_FACTOR != 0:
+                raise ValueError(
+                    f"The tile coordinates must all be divisible by the latent scale factor"
+                    f" ({LATENT_SCALE_FACTOR}). {image_tile.coords=}."
+                )
+        assert latents.dim() == 4  # We expect: (batch_size, channels, height, width).
+
+        top = image_tile.coords.top // LATENT_SCALE_FACTOR
+        left = image_tile.coords.left // LATENT_SCALE_FACTOR
+        bottom = image_tile.coords.bottom // LATENT_SCALE_FACTOR
+        right = image_tile.coords.right // LATENT_SCALE_FACTOR
+        return latents[..., top:bottom, left:right]
+
     @torch.no_grad()
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         # TODO(ryand): Expose the seed parameter.
@@ -141,7 +163,8 @@ class TiledStableDiffusionRefine(BaseInvocation):
             image_tiles.append(image_tile)
 
         # VAE-encode each image tile independently.
-        # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles?
+        # TODO(ryand): Is there any advantage to VAE-encoding the entire image before splitting it into tiles? What
+        # about for decoding?
         vae_info = context.models.load(self.vae.vae)
         latent_tiles: list[torch.Tensor] = []
         for image_tile in image_tiles:
@@ -157,7 +180,7 @@ class TiledStableDiffusionRefine(BaseInvocation):
         # noise.
         assert input_image_torch.shape[2] % LATENT_SCALE_FACTOR == 0
         assert input_image_torch.shape[3] % LATENT_SCALE_FACTOR == 0
-        noise_tiles = get_noise(
+        global_noise = get_noise(
             width=input_image_torch.shape[3],
             height=input_image_torch.shape[2],
             device=TorchDevice.choose_torch_device(),
@@ -166,6 +189,9 @@ class TiledStableDiffusionRefine(BaseInvocation):
             use_cpu=True,
         )
 
+        # Crop the global noise into tiles.
+        noise_tiles = [self.crop_latents_to_tile(latents=global_noise, image_tile=t) for t in tiles]
+
         # Load the UNet model.
         unet_info = context.models.load(self.unet.unet)
 
@@ -178,10 +204,76 @@ class TiledStableDiffusionRefine(BaseInvocation):
             scheduler_name=self.scheduler,
             seed=seed,
         )
-        pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
-        for latent_tile in latent_tiles:
+        pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler)
+
+        # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles.
+        # Assume that all tiles have the same shape.
+        _, _, latent_height, latent_width = latent_tiles[0].shape
+        conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            unet=unet,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
+        )
+
+        # Denoise (i.e. "refine") each tile independently.
+        refined_latent_tiles: list[torch.Tensor] = []
+        for latent_tile, noise_tile in zip(latent_tiles, noise_tiles, strict=True):
+            assert latent_tile.shape == noise_tile.shape
+
+            num_inference_steps, timesteps, init_timestep, scheduler_step_kwargs = (
+                DenoiseLatentsInvocation.init_scheduler(
+                    scheduler,
+                    device=unet.device,
+                    steps=self.steps,
+                    denoising_start=self.denoising_start,
+                    denoising_end=self.denoising_end,
+                    seed=seed,
+                )
+            )
+
+            refined_latent_tile = pipeline.latents_from_embeddings(
+                latents=latent_tile,
+                timesteps=timesteps,
+                init_timestep=init_timestep,
+                noise=noise_tile,
+                seed=seed,
+                mask=None,
+                masked_latents=None,
+                gradient_mask=None,
+                num_inference_steps=num_inference_steps,
+                scheduler_step_kwargs=scheduler_step_kwargs,
+                conditioning_data=conditioning_data,
+                control_data=None,
+                ip_adapter_data=None,
+                t2i_adapter_data=None,
+                callback=lambda x: None,
+            )
+            refined_latent_tiles.append(refined_latent_tile)
+
+        # VAE-decode each refined latent tile independently.
+        refined_image_tiles: list[Image.Image] = []
+        for refined_latent_tile in refined_latent_tiles:
+            refined_image_tile = LatentsToImageInvocation.vae_decode(
+                context=context,
+                vae_info=vae_info,
+                seamless_axes=self.vae.seamless_axes,
+                latents=refined_latent_tile,
+                use_fp32=self.vae_fp32,
+                use_tiling=False,
+            )
+            refined_image_tiles.append(refined_image_tile)
+
+        # Merge the refined image tiles back into a single image.
+        ...
+
+        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
+        result_latents = result_latents.to("cpu")
+        TorchDevice.empty_cache()
 
         name = context.tensors.save(tensor=result_latents)
         return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None)
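
Note on the final step: the tile-merge is intentionally left as a `...` placeholder in this WIP patch, and the sketch below is not InvokeAI's implementation. It is only a minimal, self-contained illustration of one way the refined image tiles could be recombined, assuming each tile image corresponds to the same top/left/bottom/right coordinates used by crop_latents_to_tile(). The names TileBox and merge_image_tiles are hypothetical; the real node would more plausibly reuse the blending helpers in invokeai.backend.tiles.

    # Hypothetical sketch (not part of the patch): flat-average merge of refined tiles.
    from dataclasses import dataclass

    import numpy as np
    from PIL import Image


    @dataclass
    class TileBox:
        # Stand-in for a tile's image-space coordinates (same convention as Tile.coords).
        top: int
        left: int
        bottom: int
        right: int


    def merge_image_tiles(
        tile_images: list[Image.Image],
        tile_boxes: list[TileBox],
        image_width: int,
        image_height: int,
    ) -> Image.Image:
        """Paste refined tiles onto a canvas, averaging pixels where tiles overlap.

        Averaging the overlap hides hard seams at tile boundaries; a production
        implementation would more likely apply a feathered/linear blend across the
        overlap region instead of a flat average.
        """
        accum = np.zeros((image_height, image_width, 3), dtype=np.float32)
        weight = np.zeros((image_height, image_width, 1), dtype=np.float32)

        for tile_image, box in zip(tile_images, tile_boxes, strict=True):
            tile_np = np.asarray(tile_image.convert("RGB"), dtype=np.float32)
            accum[box.top : box.bottom, box.left : box.right] += tile_np
            weight[box.top : box.bottom, box.left : box.right] += 1.0

        merged = accum / np.maximum(weight, 1.0)
        return Image.fromarray(merged.round().astype(np.uint8))

A merged PIL image produced this way would still need to be re-encoded (or saved directly as an image output) before the existing result_latents bookkeeping at the end of invoke() makes sense; that wiring is left open in the patch.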