From 35adaf1c1795274e2d7b2c5ce87bcdd72e0a32e6 Mon Sep 17 00:00:00 2001 From: Ryan Dick Date: Mon, 17 Jun 2024 16:30:34 -0400 Subject: [PATCH] Connect TiledMultiDiffusionDenoiseLatents to the MultiDiffusionPipeline backend. --- .../tiled_multi_diffusion_denoise_latents.py | 186 ++++++++---------- .../multi_diffusion_pipeline.py | 104 +++++----- 2 files changed, 132 insertions(+), 158 deletions(-) diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py index 1320909436..9d998f345f 100644 --- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py +++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py @@ -1,10 +1,10 @@ +import copy from contextlib import ExitStack from typing import Iterator, Tuple -import numpy as np import torch from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel -from PIL import Image +from diffusers.schedulers.scheduling_utils import SchedulerMixin from pydantic import field_validator from invokeai.app.invocations.baseinvocation import BaseInvocation, invocation @@ -19,21 +19,38 @@ from invokeai.app.invocations.fields import ( LatentsField, UIType, ) -from invokeai.app.invocations.latents_to_image import LatentsToImageInvocation from invokeai.app.invocations.model import UNetField from invokeai.app.invocations.noise import get_noise -from invokeai.app.invocations.primitives import ImageOutput +from invokeai.app.invocations.primitives import LatentsOutput from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.backend.lora import LoRAModelRaw from invokeai.backend.model_patcher import ModelPatcher from invokeai.backend.stable_diffusion.diffusers_pipeline import ControlNetData +from invokeai.backend.stable_diffusion.multi_diffusion_pipeline import ( + MultiDiffusionPipeline, + MultiDiffusionRegionConditioning, +) from invokeai.backend.tiles.tiles import ( calc_tiles_min_overlap, - merge_tiles_with_linear_blending, ) +from invokeai.backend.tiles.utils import TBLR from invokeai.backend.util.devices import TorchDevice +def crop_controlnet_data(control_data: ControlNetData, latent_region: TBLR) -> ControlNetData: + """Crop a ControlNetData object to a region.""" + # Create a shallow copy of the control_data object. + control_data_copy = copy.copy(control_data) + # The ControlNet reference image is the only attribute that needs to be cropped. + control_data_copy.image_tensor = control_data.image_tensor[ + :, + :, + latent_region.top * LATENT_SCALE_FACTOR : latent_region.bottom * LATENT_SCALE_FACTOR, + latent_region.left * LATENT_SCALE_FACTOR : latent_region.right * LATENT_SCALE_FACTOR, + ] + return control_data_copy + + @invocation( "tiled_multi_diffusion_denoise_latents", title="Tiled Multi-Diffusion Denoise Latents", @@ -119,8 +136,33 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): raise ValueError("cfg_scale must be greater than 1") return v + @staticmethod + def create_pipeline( + unet: UNet2DConditionModel, + scheduler: SchedulerMixin, + ) -> MultiDiffusionPipeline: + # TODO(ryand): Get rid of this FakeVae hack. + class FakeVae: + class FakeVaeConfig: + def __init__(self) -> None: + self.block_out_channels = [0] + + def __init__(self) -> None: + self.config = FakeVae.FakeVaeConfig() + + return MultiDiffusionPipeline( + vae=FakeVae(), # TODO: oh... + text_encoder=None, + tokenizer=None, + unet=unet, + scheduler=scheduler, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + @torch.no_grad() - def invoke(self, context: InvocationContext) -> ImageOutput: + def invoke(self, context: InvocationContext) -> LatentsOutput: seed, noise, latents = DenoiseLatentsInvocation.prepare_noise_and_latents(context, self.noise, self.latents) _, _, latent_height, latent_width = latents.shape @@ -149,15 +191,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): min_overlap=self.tile_min_overlap, ) - # Split the noise and latents into tiles. - noise_tiles: list[torch.Tensor] = [] - latent_tiles: list[torch.Tensor] = [] - for tile in tiles: - noise_tile = noise[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right] - latent_tile = latents[..., tile.coords.top : tile.coords.bottom, tile.coords.left : tile.coords.right] - noise_tiles.append(noise_tile) - latent_tiles.append(latent_tile) - # Prepare an iterator that yields the UNet's LoRA models and their weights. def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]: for lora in self.unet.loras: @@ -169,7 +202,6 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): # Load the UNet model. unet_info = context.models.load(self.unet.unet) - refined_latent_tiles: list[torch.Tensor] = [] with ExitStack() as exit_stack, unet_info as unet, ModelPatcher.apply_lora_unet(unet, _lora_loader()): assert isinstance(unet, UNet2DConditionModel) scheduler = get_scheduler( @@ -178,7 +210,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): scheduler_name=self.scheduler, seed=seed, ) - pipeline = DenoiseLatentsInvocation.create_pipeline(unet=unet, scheduler=scheduler) + pipeline = self.create_pipeline(unet=unet, scheduler=scheduler) # Prepare the prompt conditioning data. The same prompt conditioning is applied to all tiles. conditioning_data = DenoiseLatentsInvocation.get_conditioning_data( @@ -203,95 +235,47 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation): ) # Split the controlnet_data into tiles. - if controlnet_data is not None: - # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile. - controlnet_data_tiles: list[list[ControlNetData]] = [] - for tile in tiles: - # To split the controlnet_data into tiles, we simply need to crop each image_tensor. All other - # params can be copied unmodified. - tile_controlnet_data = [ - ControlNetData( - model=cn.model, - image_tensor=cn.image_tensor[ - :, - :, - tile.coords.top * LATENT_SCALE_FACTOR : tile.coords.bottom * LATENT_SCALE_FACTOR, - tile.coords.left * LATENT_SCALE_FACTOR : tile.coords.right * LATENT_SCALE_FACTOR, - ], - weight=cn.weight, - begin_step_percent=cn.begin_step_percent, - end_step_percent=cn.end_step_percent, - control_mode=cn.control_mode, - resize_mode=cn.resize_mode, - ) - for cn in controlnet_data - ] - controlnet_data_tiles.append(tile_controlnet_data) + # controlnet_data_tiles[t][c] is the c'th control data for the t'th tile. + controlnet_data_tiles: list[list[ControlNetData]] = [] + for tile in tiles: + tile_controlnet_data = [crop_controlnet_data(cn, tile.coords) for cn in controlnet_data or []] + controlnet_data_tiles.append(tile_controlnet_data) - # Denoise (i.e. "refine") each tile independently. - for image_tile_np, latent_tile, noise_tile in zip(image_tiles_np, latent_tiles, noise_tiles, strict=True): - assert latent_tile.shape == noise_tile.shape - - # Prepare a PIL Image for ControlNet processing. - # TODO(ryand): This is a bit awkward that we have to prepare both torch.Tensor and PIL.Image versions of - # the tiles. Ideally, the ControlNet code should be able to work with Tensors. - image_tile_pil = Image.fromarray(image_tile_np) - - timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler( - scheduler, - device=unet.device, - steps=self.steps, - denoising_start=self.denoising_start, - denoising_end=self.denoising_end, - seed=seed, + # Prepare the MultiDiffusionRegionConditioning list. + multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning] = [] + for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True): + multi_diffusion_conditioning.append( + MultiDiffusionRegionConditioning( + region=tile.coords, + text_conditioning_data=conditioning_data, + control_data=tile_controlnet_data, + ) ) - # TODO(ryand): Think about when/if latents/noise should be moved off of the device to save VRAM. - latent_tile = latent_tile.to(device=unet.device, dtype=unet.dtype) - noise_tile = noise_tile.to(device=unet.device, dtype=unet.dtype) - refined_latent_tile = pipeline.latents_from_embeddings( - latents=latent_tile, - timesteps=timesteps, - init_timestep=init_timestep, - noise=noise_tile, - seed=seed, - mask=None, - masked_latents=None, - scheduler_step_kwargs=scheduler_step_kwargs, - conditioning_data=conditioning_data, - control_data=[controlnet_data], - ip_adapter_data=None, - t2i_adapter_data=None, - callback=lambda x: None, - ) - refined_latent_tiles.append(refined_latent_tile) - - # VAE-decode each refined latent tile independently. - refined_image_tiles: list[Image.Image] = [] - for refined_latent_tile in refined_latent_tiles: - refined_image_tile = LatentsToImageInvocation.vae_decode( - context=context, - vae_info=vae_info, - seamless_axes=self.vae.seamless_axes, - latents=refined_latent_tile, - use_fp32=self.vae_fp32, - use_tiling=False, + timesteps, init_timestep, scheduler_step_kwargs = DenoiseLatentsInvocation.init_scheduler( + scheduler, + device=unet.device, + steps=self.steps, + denoising_start=self.denoising_start, + denoising_end=self.denoising_end, + seed=seed, + ) + + # Run Multi-Diffusion denoising. + result_latents = pipeline.multi_diffusion_denoise( + multi_diffusion_conditioning=multi_diffusion_conditioning, + latents=latents, + scheduler_step_kwargs=scheduler_step_kwargs, + noise=noise, + timesteps=timesteps, + init_timestep=init_timestep, + # TODO(ryand): Add proper callback. + callback=lambda x: None, ) - refined_image_tiles.append(refined_image_tile) # TODO(ryand): I copied this from DenoiseLatentsInvocation. I'm not sure if it's actually important. + result_latents = result_latents.to("cpu") TorchDevice.empty_cache() - # Merge the refined image tiles back into a single image. - refined_image_tiles_np = [np.array(t) for t in refined_image_tiles] - merged_image_np = np.zeros(shape=(input_image.height, input_image.width, 3), dtype=np.uint8) - # TODO(ryand): Tune the blend_amount. Should this be exposed as a parameter? - merge_tiles_with_linear_blending( - dst_image=merged_image_np, tiles=tiles, tile_images=refined_image_tiles_np, blend_amount=self.tile_overlap - ) - - # Save the refined image and return its reference. - merged_image_pil = Image.fromarray(merged_image_np) - image_dto = context.images.save(image=merged_image_pil) - - return ImageOutput.build(image_dto) + name = context.tensors.save(tensor=result_latents) + return LatentsOutput.build(latents_name=name, latents=result_latents, seed=None) diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py index 435140523f..2049a19733 100644 --- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py +++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py @@ -1,6 +1,6 @@ from __future__ import annotations -import copy +from dataclasses import dataclass from typing import Any, Callable, Optional import torch @@ -11,7 +11,15 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import ( StableDiffusionGeneratorPipeline, ) from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData -from invokeai.backend.tiles.utils import Tile +from invokeai.backend.tiles.utils import TBLR + + +@dataclass +class MultiDiffusionRegionConditioning: + # Region coords in latent space. + region: TBLR + text_conditioning_data: TextConditioningData + control_data: list[ControlNetData] class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): @@ -45,15 +53,13 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): # - May need a cleaner AddsMaskGuidance implementation to handle this plan... we'll see. def multi_diffusion_denoise( self, - regions: list[Tile], + multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning], latents: torch.Tensor, scheduler_step_kwargs: dict[str, Any], - conditioning_data: TextConditioningData, noise: Optional[torch.Tensor], timesteps: torch.Tensor, init_timestep: torch.Tensor, callback: Callable[[PipelineIntermediateState], None], - control_data: list[ControlNetData] | None = None, ) -> torch.Tensor: # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle # cases where densoisings_start and denoising_end are set such that there are no timesteps. @@ -74,21 +80,14 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): # cropping into regions. self._adjust_memory_efficient_attention(latents) - use_regional_prompting = ( - conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None - ) - if use_regional_prompting: - raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") - # Populate a weighted mask that will be used to combine the results from each region after every step. # For now, we assume that each regions has the same weight (1.0). region_weight_mask = torch.zeros( (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype ) - for region in regions: - region_weight_mask[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += 1.0 + for region_conditioning in multi_diffusion_conditioning: + region = region_conditioning.region + region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0 callback( PipelineIntermediateState( @@ -103,39 +102,36 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): for i, t in enumerate(self.progress_bar(timesteps)): batched_t = t.expand(batch_size) - prev_samples_by_region: list[torch.Tensor] = [] - pred_original_by_region: list[torch.Tensor | None] = [] - for region in regions: + merged_latents = torch.zeros_like(latents) + merged_pred_original: torch.Tensor | None = None + for region_conditioning in multi_diffusion_conditioning: # Run a denoising step on the region. step_output = self._region_step( - region=region, + region_conditioning=region_conditioning, t=batched_t, latents=latents, - conditioning_data=conditioning_data, step_index=i, total_step_count=len(timesteps), scheduler_step_kwargs=scheduler_step_kwargs, - control_data=control_data, ) - prev_samples_by_region.append(step_output.prev_sample) - pred_original_by_region.append(getattr(step_output, "pred_original_sample", None)) - # Merge the prev_sample results from each region. - merged_latents = torch.zeros_like(latents) - for region_idx, region in enumerate(regions): - merged_latents[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += prev_samples_by_region[region_idx] + # Store the results from the region. + region = region_conditioning.region + merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample + pred_orig_sample = getattr(step_output, "pred_original_sample", None) + if pred_orig_sample is not None: + # If one region has pred_original_sample, then we can assume that all regions will have it, because + # they all use the same scheduler. + if merged_pred_original is None: + merged_pred_original = torch.zeros_like(latents) + merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += ( + pred_orig_sample + ) + + # Normalize the merged results. latents = merged_latents / region_weight_mask - - # Merge the predicted_original results from each region. predicted_original = None - if all(pred_original_by_region): - merged_pred_original = torch.zeros_like(latents) - for region_idx, region in enumerate(regions): - merged_pred_original[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] += pred_original_by_region[region_idx] + if merged_pred_original is not None: predicted_original = merged_pred_original / region_weight_mask callback( @@ -154,44 +150,38 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline): @torch.inference_mode() def _region_step( self, - region: Tile, + region_conditioning: MultiDiffusionRegionConditioning, t: torch.Tensor, latents: torch.Tensor, - conditioning_data: TextConditioningData, step_index: int, total_step_count: int, scheduler_step_kwargs: dict[str, Any], - control_data: list[ControlNetData] | None = None, ): + use_regional_prompting = ( + region_conditioning.text_conditioning_data.cond_regions is not None + or region_conditioning.text_conditioning_data.uncond_regions is not None + ) + if use_regional_prompting: + raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.") + # Crop the inputs to the region. region_latents = latents[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right + :, + :, + region_conditioning.region.top : region_conditioning.region.bottom, + region_conditioning.region.left : region_conditioning.region.right, ] - region_control_data: list[ControlNetData] | None = None - if control_data is not None: - region_control_data = [self._crop_controlnet_data(c, region) for c in control_data] - # Run the denoising step on the region. return self.step( t=t, latents=region_latents, - conditioning_data=conditioning_data, + conditioning_data=region_conditioning.text_conditioning_data, step_index=step_index, total_step_count=total_step_count, scheduler_step_kwargs=scheduler_step_kwargs, mask_guidance=None, mask=None, masked_latents=None, - control_data=region_control_data, + control_data=region_conditioning.control_data, ) - - def _crop_controlnet_data(self, control_data: ControlNetData, region: Tile) -> ControlNetData: - """Crop a ControlNetData object to a region.""" - # Create a shallow copy of the control_data object. - control_data_copy = copy.copy(control_data) - # The ControlNet reference image is the only attribute that needs to be cropped. - control_data_copy.image_tensor = control_data.image_tensor[ - :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right - ] - return control_data_copy