mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add gradient blending to tile seams in MultiDiffusion.
This commit is contained in:
parent
97a7f51721
commit
e16faa6370
@ -175,6 +175,10 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
||||
_, _, latent_height, latent_width = latents.shape
|
||||
|
||||
# Calculate the tile locations to cover the latent-space image.
|
||||
# TODO(ryand): In the future, we may want to revisit the tile overlap strategy. Things to consider:
|
||||
# - How much overlap 'context' to provide for each denoising step.
|
||||
# - How much overlap to use during merging/blending.
|
||||
# - Should we 'jitter' the tile locations in each step so that the seams are in different places?
|
||||
tiles = calc_tiles_min_overlap(
|
||||
image_height=latent_height,
|
||||
image_width=latent_width,
|
||||
|
@ -61,6 +61,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
# full noise. Investigate the history of why this got commented out.
|
||||
# latents = noise * self.scheduler.init_noise_sigma # it's like in t2l according to diffusers
|
||||
latents = self.scheduler.add_noise(latents, noise, batched_init_timestep)
|
||||
assert isinstance(latents, torch.Tensor) # For static type checking.
|
||||
|
||||
# TODO(ryand): Look into the implications of passing in latents here that are larger than they will be after
|
||||
# cropping into regions.
|
||||
@ -122,19 +123,42 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
control_data=region_conditioning.control_data,
|
||||
)
|
||||
|
||||
# Store the results from the region.
|
||||
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
|
||||
# affected tiles to achieve the target overlap.
|
||||
# Build a region_weight matrix that applies gradient blending to the edges of the region.
|
||||
region = region_conditioning.region
|
||||
top_adjustment = max(0, region.overlap.top - target_overlap)
|
||||
left_adjustment = max(0, region.overlap.left - target_overlap)
|
||||
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
|
||||
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
|
||||
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
|
||||
:, :, top_adjustment:, left_adjustment:
|
||||
]
|
||||
# For now, we treat every region as having the same weight.
|
||||
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
|
||||
_, _, region_height, region_width = step_output.prev_sample.shape
|
||||
region_weight = torch.ones(
|
||||
(1, 1, region_height, region_width),
|
||||
dtype=latents.dtype,
|
||||
device=latents.device,
|
||||
)
|
||||
if region.overlap.left > 0:
|
||||
left_grad = torch.linspace(
|
||||
0, 1, region.overlap.left, device=latents.device, dtype=latents.dtype
|
||||
).view((1, 1, 1, -1))
|
||||
region_weight[:, :, :, : region.overlap.left] *= left_grad
|
||||
if region.overlap.top > 0:
|
||||
top_grad = torch.linspace(
|
||||
0, 1, region.overlap.top, device=latents.device, dtype=latents.dtype
|
||||
).view((1, 1, -1, 1))
|
||||
region_weight[:, :, : region.overlap.top, :] *= top_grad
|
||||
if region.overlap.right > 0:
|
||||
right_grad = torch.linspace(
|
||||
1, 0, region.overlap.right, device=latents.device, dtype=latents.dtype
|
||||
).view((1, 1, 1, -1))
|
||||
region_weight[:, :, :, -region.overlap.right :] *= right_grad
|
||||
if region.overlap.bottom > 0:
|
||||
bottom_grad = torch.linspace(
|
||||
1, 0, region.overlap.bottom, device=latents.device, dtype=latents.dtype
|
||||
).view((1, 1, -1, 1))
|
||||
region_weight[:, :, -region.overlap.bottom :, :] *= bottom_grad
|
||||
|
||||
# Update the merged results with the region results.
|
||||
merged_latents[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += step_output.prev_sample * region_weight
|
||||
merged_latents_weights[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += region_weight
|
||||
|
||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||
if pred_orig_sample is not None:
|
||||
@ -142,9 +166,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
||||
# they all use the same scheduler.
|
||||
if merged_pred_original is None:
|
||||
merged_pred_original = torch.zeros_like(latents)
|
||||
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
|
||||
:, :, top_adjustment:, left_adjustment:
|
||||
]
|
||||
merged_pred_original[
|
||||
:, :, region.coords.top : region.coords.bottom, region.coords.left : region.coords.right
|
||||
] += pred_orig_sample
|
||||
|
||||
# Normalize the merged results.
|
||||
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
|
||||
|
Loading…
Reference in New Issue
Block a user