diff --git a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
index 5d408a4df7..edb7691e4d 100644
--- a/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
+++ b/invokeai/app/invocations/tiled_multi_diffusion_denoise_latents.py
@@ -175,6 +175,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
         _, _, latent_height, latent_width = latents.shape
 
         # Calculate the tile locations to cover the latent-space image.
+        # TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
+        # - How much overlap 'context' to provide for each denoising step.
+        # - How much overlap to use during merging/blending.
+        # - Should we 'jitter' the tile locations in each step so that the seams are in different places?
         tiles = calc_tiles_min_overlap(
             image_height=latent_height,
             image_width=latent_width,
diff --git a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
index 0ddcfdd380..6c07fc1c2c 100644
--- a/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
+++ b/invokeai/backend/stable_diffusion/multi_diffusion_pipeline.py
@@ -61,6 +61,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             # full noise. Investigate the history of why this got commented out.
             # latents = noise * self.scheduler.init_noise_sigma  # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
+            assert isinstance(latents, torch.Tensor)  # For static type checking.
 
         # TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
         # cropping into regions.
@@ -122,19 +123,42 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                     control_data=region_conditioning.control_data,
                 )
 
-                # Store the results from the region.
-                # If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
-                # affected tiles to achieve the target overlap.
+                # Build a region_weight matrix that applies gradient blending to the edges of the region.
                 region = region_conditioning.region
-                top_adjustment = max(0, region.overlap.top - target_overlap)
-                left_adjustment = max(0, region.overlap.left - target_overlap)
-                region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
-                region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
-                merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
-                    :, :, top_adjustment:, left_adjustment:
-                ]
-                # For now, we treat every region as having the same weight.
-                merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
+                _, _, region_height, region_width = step_output.prev_sample.shape
+                region_weight = torch.ones(
+                    (1, 1, region_height, region_width),
+                    dtype=latents.dtype,
+                    device=latents.device,
+                )
+                if region.overlap.left > 0:
+                    left_grad = torch.linspace(
+                        0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
+                    ).view((1, 1, 1, -1))
+                    region_weight[:, :, :, : region.overlap.left] *= left_grad
+                if region.overlap.top > 0:
+                    top_grad = torch.linspace(
+                        0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
+                    ).view((1, 1, -1, 1))
+                    region_weight[:, :, : region.overlap.top, :] *= top_grad
+                if region.overlap.right > 0:
+                    right_grad = torch.linspace(
+                        1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
+                    ).view((1, 1, 1, -1))
+                    region_weight[:, :, :, -region.overlap.right :] *= right_grad
+                if region.overlap.bottom > 0:
+                    bottom_grad = torch.linspace(
+                        1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
+                    ).view((1, 1, -1, 1))
+                    region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
+
+                # Update the merged results with the region results.
+                merged_latents[
+                    :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                ] += step_output.prev_sample * region_weight
+                merged_latents_weights[
+                    :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                ] += region_weight
 
                 pred_orig_sample = getattr(step_output, "pred_original_sample", None)
                 if pred_orig_sample is not None:
@@ -142,9 +166,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
                     # they all use the same scheduler.
                     if merged_pred_original is None:
                         merged_pred_original = torch.zeros_like(latents)
-                    merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
-                        :, :, top_adjustment:, left_adjustment:
-                    ]
+                    merged_pred_original[
+                        :, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
+                    ] += pred_orig_sample
 
         # Normalize the merged results.
         latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
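
Note for reviewers (not part of the patch): the gradient-blending merge introduced above can be illustrated in isolation. The standalone sketch below uses made-up 1-D tile sizes and overlap; it applies the same linear-ramp weighting on overlapping edges and the same weighted-sum-then-normalize merge that the new region_weight logic performs per edge in 2-D.

import torch

# Two overlapping 1-D "tiles" over a length-12 canvas (hypothetical sizes).
# Tile A covers [0, 8), tile B covers [4, 12); they overlap by 4.
canvas_len, overlap = 12, 4
tile_a = torch.full((8,), 1.0)  # stand-in for tile A's denoised values
tile_b = torch.full((8,), 3.0)  # stand-in for tile B's denoised values

# Per-tile weights: 1 everywhere, with a linear ramp on the edge that
# overlaps a neighbor (1 -> 0 on A's right edge, 0 -> 1 on B's left edge),
# mirroring the patch's torch.linspace edge gradients.
weight_a = torch.ones(8)
weight_a[-overlap:] *= torch.linspace(1, 0, overlap)
weight_b = torch.ones(8)
weight_b[:overlap] *= torch.linspace(0, 1, overlap)

# Accumulate weighted samples and weights, then normalize, as the patch
# does with merged_latents / merged_latents_weights.
merged = torch.zeros(canvas_len)
merged_weights = torch.zeros(canvas_len)
merged[0:8] += tile_a * weight_a
merged_weights[0:8] += weight_a
merged[4:12] += tile_b * weight_b
merged_weights[4:12] += weight_b
blended = torch.where(merged_weights > 0, merged / merged_weights, merged)

print(blended)  # values ramp smoothly from 1.0 to 3.0 across the overlap

Because the opposing ramps sum to 1 inside the overlap, the merge is a convex combination there: no seam, and no region is double-counted after normalization.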