From e1af78c70204c388ca11858bd169b69ed12908b6 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Tue, 25 Jun 2024 09:57:40 -0400
Subject: [PATCH] Make the tile_overlap input to MultiDiffusion *strictly* control the amount of overlap rather than being a lower bound.

---
 .../tiled_multi_diffusion_denoise_latents.py | 11 +++--
 .../multi_diffusion_pipeline.py              | 46 +++++++++++--------
 2 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 717f8e1019..bb548e3149 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -86,11 +86,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
     )
     tile_height: int = InputField(default=64, gt=0, description="Height of the tiles in latent space.")
     tile_width: int = InputField(default=64, gt=0, description="Width of the tiles in latent space.")
-    tile_min_overlap: int = InputField(
+    tile_overlap: int = InputField(
         default=16,
         gt=0,
-        description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
-        "this to evenly cover the entire image.",
+        description="The overlap between adjacent tiles in latent space. Tiles will be cropped during merging "
+        "(if necessary) to ensure that they overlap by exactly this amount.",
     )
     steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
     cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
@@ -167,7 +167,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
             image_width=latent_width,
             tile_height=self.tile_height,
             tile_width=self.tile_width,
-            min_overlap=self.tile_min_overlap,
+            min_overlap=self.tile_overlap,
         )
 
         # Get the unet's config so that we can pass the base to sd_step_callback().
@@ -234,7 +234,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
             multi_diffusion_conditioning.append(
                 MultiDiffusionRegionConditioning(
-                    region=tile.coords,
+                    region=tile,
                     text_conditioning_data=conditioning_data,
                     control_data=tile_controlnet_data,
                 )
@@ -252,6 +252,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         # Run Multi-Diffusion denoising.
         result_latents = pipeline.multi_diffusion_denoise(
             multi_diffusion_conditioning=multi_diffusion_conditioning,
+            target_overlap=self.tile_overlap,
             latents=latents,
             scheduler_step_kwargs=scheduler_step_kwargs,
             noise=noise,
diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
index b6da66de00..8036ca3e01 100644
--- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
+++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
@@ -13,13 +13,13 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
     StableDiffusionGeneratorPipeline,
 )
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
-from invokeai.backend.tiles.utils import TBLR
+from invokeai.backend.tiles.utils import Tile
 
 
 @dataclass
 class MultiDiffusionRegionConditioning:
     # Region coords in latent space.
-    region: TBLR
+    region: Tile
     text_conditioning_data: TextConditioningData
     control_data: list[ControlNetData]
 
@@ -39,6 +39,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
     def multi_diffusion_denoise(
         self,
         multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
+        target_overlap: int,
         latents: torch.Tensor,
         scheduler_step_kwargs: dict[str, Any],
         noise: Optional[torch.Tensor],
@@ -66,15 +67,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         # cropping into regions.
         self._adjust_memory_efficient_attention(latents)
 
-        # Populate a weighted mask that will be used to combine the results from each region after every step.
-        # For now, we assume that each region has the same weight (1.0).
-        region_weight_mask = torch.zeros(
-            (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
-        )
-        for region_conditioning in multi_diffusion_conditioning:
-            region = region_conditioning.region
-            region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
-
         # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
         # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
         # separate scheduler state for each region batch.
@@ -101,6 +93,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             batched_t = t.expand(batch_size)
 
             merged_latents = torch.zeros_like(latents)
+            merged_latents_weights = torch.zeros(
+                (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
+            )
             merged_pred_original: torch.Tensor | None = None
             for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
                 # Switch to the scheduler for the region batch.
@@ -110,8 +105,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                 region_latents = latents[
                     :,
                     :,
-                    region_conditioning.region.top : region_conditioning.region.bottom,
-                    region_conditioning.region.left : region_conditioning.region.right,
+                    region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
+                    region_conditioning.region.coords.left : region_conditioning.region.coords.right,
                 ]
 
                 # Run the denoising step on the region.
@@ -129,24 +124,37 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                 )
 
                 # Store the results from the region.
+                # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
+                # affected tiles to achieve the target overlap.
                 region = region_conditioning.region
-                merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
+                top_adjustment = max(0, region.overlap.top - target_overlap)
+                left_adjustment = max(0, region.overlap.left - target_overlap)
+                region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
+                region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
+                merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
+                    :, :, top_adjustment:, left_adjustment:
+                ]
+                # For now, we treat every region as having the same weight.
+                merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
+
                 pred_orig_sample = getattr(step_output, "pred_original_sample", None)
                 if pred_orig_sample is not None:
                     # If one region has pred_original_sample, then we can assume that all regions will have it, because
                     # they all use the same scheduler.
                     if merged_pred_original is None:
                         merged_pred_original = torch.zeros_like(latents)
-                    merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
-                        pred_orig_sample
-                    )
+                    merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
+                        :, :, top_adjustment:, left_adjustment:
+                    ]
 
             # Normalize the merged results.
-            latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
+            latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
+            # For debugging, uncomment this line to visualize the region seams:
+            # latents = torch.where(merged_latents_weights > 1, 0.0, latents)
             predicted_original = None
             if merged_pred_original is not None:
                 predicted_original = torch.where(
-                    region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
+                    merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
                 )
 
             callback(
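The following is a minimal standalone sketch (not part of the patch) of the exact-overlap merge that the new code performs: a tile whose overlap with its neighbor exceeds target_overlap is cropped on its top/left edge before being accumulated into a weighted sum, so the realized overlap is exactly target_overlap. The FakeTile/FakeTBLR dataclasses and the example coordinates are illustrative stand-ins for the real Tile and TBLR types in invokeai.backend.tiles.utils.

# Standalone sketch of the exact-overlap merge; names and coordinates are made up for illustration.
from dataclasses import dataclass

import torch


@dataclass
class FakeTBLR:
    top: int
    bottom: int
    left: int
    right: int


@dataclass
class FakeTile:
    coords: FakeTBLR   # where the tile sits in the full latent image
    overlap: FakeTBLR  # how far the tile overlaps its neighbors on each edge


def merge_tiles(
    outputs: list[torch.Tensor], tiles: list[FakeTile], height: int, width: int, target_overlap: int
) -> torch.Tensor:
    merged = torch.zeros((1, 1, height, width))
    weights = torch.zeros((1, 1, height, width))
    for out, tile in zip(outputs, tiles, strict=True):
        # Crop the top/left edge of any tile that overlaps its neighbor by more than
        # target_overlap, so the realized overlap is exactly target_overlap.
        top_adjustment = max(0, tile.overlap.top - target_overlap)
        left_adjustment = max(0, tile.overlap.left - target_overlap)
        height_slice = slice(tile.coords.top + top_adjustment, tile.coords.bottom)
        width_slice = slice(tile.coords.left + left_adjustment, tile.coords.right)
        merged[:, :, height_slice, width_slice] += out[:, :, top_adjustment:, left_adjustment:]
        weights[:, :, height_slice, width_slice] += 1.0  # every tile has equal weight
    # Average wherever tiles overlap; untouched pixels stay at zero.
    return torch.where(weights > 0, merged / weights, merged)


# Two 64-wide tiles over a 96-wide latent overlap by 32 columns. With target_overlap=16,
# the right tile's left edge is cropped by 16, so only columns 48-63 are blended.
tiles = [
    FakeTile(coords=FakeTBLR(0, 64, 0, 64), overlap=FakeTBLR(0, 0, 0, 0)),
    FakeTile(coords=FakeTBLR(0, 64, 32, 96), overlap=FakeTBLR(0, 0, 32, 0)),
]
outputs = [torch.ones((1, 1, 64, 64)), 2 * torch.ones((1, 1, 64, 64))]
result = merge_tiles(outputs, tiles, height=64, width=96, target_overlap=16)
print(result[0, 0, 0, 30:70])  # 1.0 up to column 47, 1.5 in columns 48-63, 2.0 from column 64 on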