mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Make the tile_overlap input to MultiDiffusion *strictly* control the amount of overlap rather than being a lower bound.
This commit is contained in:
parent
c5588e1ff7
commit
e1af78c702
@ -86,11 +86,11 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
)
|
)
|
||||||
tile_height: int = InputField(default=64, gt=0, description="Height of the tiles in latent space.")
|
tile_height: int = InputField(default=64, gt=0, description="Height of the tiles in latent space.")
|
||||||
tile_width: int = InputField(default=64, gt=0, description="Width of the tiles in latent space.")
|
tile_width: int = InputField(default=64, gt=0, description="Width of the tiles in latent space.")
|
||||||
tile_min_overlap: int = InputField(
|
tile_overlap: int = InputField(
|
||||||
default=16,
|
default=16,
|
||||||
gt=0,
|
gt=0,
|
||||||
description="The minimum overlap between adjacent tiles in latent space. The actual overlap may be larger than "
|
description="The overlap between adjacent tiles in latent space. Tiles will be cropped during merging "
|
||||||
"this to evenly cover the entire image.",
|
"(if necessary) to ensure that they overlap by exactly this amount.",
|
||||||
)
|
)
|
||||||
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
steps: int = InputField(default=18, gt=0, description=FieldDescriptions.steps)
|
||||||
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
cfg_scale: float | list[float] = InputField(default=6.0, description=FieldDescriptions.cfg_scale, title="CFG Scale")
|
||||||
@ -167,7 +167,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
image_width=latent_width,
|
image_width=latent_width,
|
||||||
tile_height=self.tile_height,
|
tile_height=self.tile_height,
|
||||||
tile_width=self.tile_width,
|
tile_width=self.tile_width,
|
||||||
min_overlap=self.tile_min_overlap,
|
min_overlap=self.tile_overlap,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get the unet's config so that we can pass the base to sd_step_callback().
|
# Get the unet's config so that we can pass the base to sd_step_callback().
|
||||||
@ -234,7 +234,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
for tile, tile_controlnet_data in zip(tiles, controlnet_data_tiles, strict=True):
|
||||||
multi_diffusion_conditioning.append(
|
multi_diffusion_conditioning.append(
|
||||||
MultiDiffusionRegionConditioning(
|
MultiDiffusionRegionConditioning(
|
||||||
region=tile.coords,
|
region=tile,
|
||||||
text_conditioning_data=conditioning_data,
|
text_conditioning_data=conditioning_data,
|
||||||
control_data=tile_controlnet_data,
|
control_data=tile_controlnet_data,
|
||||||
)
|
)
|
||||||
@ -252,6 +252,7 @@ class TiledMultiDiffusionDenoiseLatents(BaseInvocation):
|
|||||||
# Run Multi-Diffusion denoising.
|
# Run Multi-Diffusion denoising.
|
||||||
result_latents = pipeline.multi_diffusion_denoise(
|
result_latents = pipeline.multi_diffusion_denoise(
|
||||||
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
multi_diffusion_conditioning=multi_diffusion_conditioning,
|
||||||
|
target_overlap=self.tile_overlap,
|
||||||
latents=latents,
|
latents=latents,
|
||||||
scheduler_step_kwargs=scheduler_step_kwargs,
|
scheduler_step_kwargs=scheduler_step_kwargs,
|
||||||
noise=noise,
|
noise=noise,
|
||||||
|
@ -13,13 +13,13 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
|
|||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
|
||||||
from invokeai.backend.tiles.utils import TBLR
|
from invokeai.backend.tiles.utils import Tile
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MultiDiffusionRegionConditioning:
|
class MultiDiffusionRegionConditioning:
|
||||||
# Region coords in latent space.
|
# Region coords in latent space.
|
||||||
region: TBLR
|
region: Tile
|
||||||
text_conditioning_data: TextConditioningData
|
text_conditioning_data: TextConditioningData
|
||||||
control_data: list[ControlNetData]
|
control_data: list[ControlNetData]
|
||||||
|
|
||||||
@ -39,6 +39,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
def multi_diffusion_denoise(
|
def multi_diffusion_denoise(
|
||||||
self,
|
self,
|
||||||
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
|
||||||
|
target_overlap: int,
|
||||||
latents: torch.Tensor,
|
latents: torch.Tensor,
|
||||||
scheduler_step_kwargs: dict[str, Any],
|
scheduler_step_kwargs: dict[str, Any],
|
||||||
noise: Optional[torch.Tensor],
|
noise: Optional[torch.Tensor],
|
||||||
@ -66,15 +67,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
# cropping into regions.
|
# cropping into regions.
|
||||||
self._adjust_memory_efficient_attention(latents)
|
self._adjust_memory_efficient_attention(latents)
|
||||||
|
|
||||||
# Populate a weighted mask that will be used to combine the results from each region after every step.
|
|
||||||
# For now, we assume that each region has the same weight (1.0).
|
|
||||||
region_weight_mask = torch.zeros(
|
|
||||||
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
|
||||||
)
|
|
||||||
for region_conditioning in multi_diffusion_conditioning:
|
|
||||||
region = region_conditioning.region
|
|
||||||
region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
|
|
||||||
|
|
||||||
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
# Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
|
||||||
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
# we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
|
||||||
# separate scheduler state for each region batch.
|
# separate scheduler state for each region batch.
|
||||||
@ -101,6 +93,9 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
batched_t = t.expand(batch_size)
|
batched_t = t.expand(batch_size)
|
||||||
|
|
||||||
merged_latents = torch.zeros_like(latents)
|
merged_latents = torch.zeros_like(latents)
|
||||||
|
merged_latents_weights = torch.zeros(
|
||||||
|
(1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
|
||||||
|
)
|
||||||
merged_pred_original: torch.Tensor | None = None
|
merged_pred_original: torch.Tensor | None = None
|
||||||
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
|
for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
|
||||||
# Switch to the scheduler for the region batch.
|
# Switch to the scheduler for the region batch.
|
||||||
@ -110,8 +105,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
region_latents = latents[
|
region_latents = latents[
|
||||||
:,
|
:,
|
||||||
:,
|
:,
|
||||||
region_conditioning.region.top : region_conditioning.region.bottom,
|
region_conditioning.region.coords.top : region_conditioning.region.coords.bottom,
|
||||||
region_conditioning.region.left : region_conditioning.region.right,
|
region_conditioning.region.coords.left : region_conditioning.region.coords.right,
|
||||||
]
|
]
|
||||||
|
|
||||||
# Run the denoising step on the region.
|
# Run the denoising step on the region.
|
||||||
@ -129,24 +124,37 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Store the results from the region.
|
# Store the results from the region.
|
||||||
|
# If two tiles overlap by more than the target overlap amount, crop the left and top edges of the
|
||||||
|
# affected tiles to achieve the target overlap.
|
||||||
region = region_conditioning.region
|
region = region_conditioning.region
|
||||||
merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
|
top_adjustment = max(0, region.overlap.top - target_overlap)
|
||||||
|
left_adjustment = max(0, region.overlap.left - target_overlap)
|
||||||
|
region_height_slice = slice(region.coords.top + top_adjustment, region.coords.bottom)
|
||||||
|
region_width_slice = slice(region.coords.left + left_adjustment, region.coords.right)
|
||||||
|
merged_latents[:, :, region_height_slice, region_width_slice] += step_output.prev_sample[
|
||||||
|
:, :, top_adjustment:, left_adjustment:
|
||||||
|
]
|
||||||
|
# For now, we treat every region as having the same weight.
|
||||||
|
merged_latents_weights[:, :, region_height_slice, region_width_slice] += 1.0
|
||||||
|
|
||||||
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
pred_orig_sample = getattr(step_output, "pred_original_sample", None)
|
||||||
if pred_orig_sample is not None:
|
if pred_orig_sample is not None:
|
||||||
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
# If one region has pred_original_sample, then we can assume that all regions will have it, because
|
||||||
# they all use the same scheduler.
|
# they all use the same scheduler.
|
||||||
if merged_pred_original is None:
|
if merged_pred_original is None:
|
||||||
merged_pred_original = torch.zeros_like(latents)
|
merged_pred_original = torch.zeros_like(latents)
|
||||||
merged_pred_original[:, :, region.top : region.bottom, region.left : region.right] += (
|
merged_pred_original[:, :, region_height_slice, region_width_slice] += pred_orig_sample[
|
||||||
pred_orig_sample
|
:, :, top_adjustment:, left_adjustment:
|
||||||
)
|
]
|
||||||
|
|
||||||
# Normalize the merged results.
|
# Normalize the merged results.
|
||||||
latents = torch.where(region_weight_mask > 0, merged_latents / region_weight_mask, merged_latents)
|
latents = torch.where(merged_latents_weights > 0, merged_latents / merged_latents_weights, merged_latents)
|
||||||
|
# For debugging, uncomment this line to visualize the region seams:
|
||||||
|
# latents = torch.where(merged_latents_weights > 1, 0.0, latents)
|
||||||
predicted_original = None
|
predicted_original = None
|
||||||
if merged_pred_original is not None:
|
if merged_pred_original is not None:
|
||||||
predicted_original = torch.where(
|
predicted_original = torch.where(
|
||||||
region_weight_mask > 0, merged_pred_original / region_weight_mask, merged_pred_original
|
merged_latents_weights > 0, merged_pred_original / merged_latents_weights, merged_pred_original
|
||||||
)
|
)
|
||||||
|
|
||||||
callback(
|
callback(
|
||||||
|
Loading…
Reference in New Issue
Block a user