WIP - Started working towards MultiDiffusion batching.

Ryan Dick 2024-06-18 14:35:41 -04:00
parent b1bb1511fe
commit 6bcf48aa37


@@ -15,6 +15,10 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
 from invokeai.backend.tiles.utils import TBLR
 
+# The maximum number of regions with compatible sizes that will be batched together.
+# Larger batch sizes improve speed, but require more device memory.
+MAX_REGION_BATCH_SIZE = 4
+
 
 @dataclass
 class MultiDiffusionRegionConditioning:
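
The cap trades speed for memory: per timestep, the number of UNet invocations drops from one per region to one per batch. A minimal sketch of that arithmetic (the helper below is illustrative, not part of the commit):

import math

MAX_REGION_BATCH_SIZE = 4

def unet_calls_per_step(num_regions: int) -> int:
    # Best case (all regions share one shape): regions are packed into
    # ceil(n / MAX_REGION_BATCH_SIZE) batches instead of n single calls.
    return math.ceil(num_regions / MAX_REGION_BATCH_SIZE)

print(unet_calls_per_step(10))  # 3 batches (4 + 4 + 2) instead of 10 calls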
@@ -27,6 +31,38 @@ class MultiDiffusionRegionConditioning:
 class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
     """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
 
+    def _split_into_region_batches(
+        self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
+    ) -> list[list[MultiDiffusionRegionConditioning]]:
+        # Group the regions by shape. Only regions with the same shape can be batched together.
+        conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
+        for region_conditioning in multi_diffusion_conditioning:
+            shape_hw = (
+                region_conditioning.region.bottom - region_conditioning.region.top,
+                region_conditioning.region.right - region_conditioning.region.left,
+            )
+            # In Python, a tuple of hashable objects is itself hashable, so it can be used as a dict key.
+            if shape_hw not in conditioning_by_shape:
+                conditioning_by_shape[shape_hw] = []
+            conditioning_by_shape[shape_hw].append(region_conditioning)
+
+        # Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
+        region_conditioning_batches = []
+        for region_conditioning_batch in conditioning_by_shape.values():
+            for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
+                region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
+        return region_conditioning_batches
+
+    def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
+        """Check the input conditioning and confirm that regional prompting is not used."""
+        for region_conditioning in multi_diffusion_conditioning:
+            if (
+                region_conditioning.text_conditioning_data.cond_regions is not None
+                or region_conditioning.text_conditioning_data.uncond_regions is not None
+            ):
+                raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
+
     def multi_diffusion_denoise(
         self,
         multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
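
A standalone sketch of the grouping behavior above, using bare (height, width) tuples in place of MultiDiffusionRegionConditioning objects (the region shapes are made up for illustration):

MAX_REGION_BATCH_SIZE = 4

# Five same-shaped regions and one odd one out.
region_shapes = [(64, 64), (64, 64), (64, 64), (64, 64), (64, 64), (32, 64)]

# Group by (height, width): only same-shaped regions can be stacked into one batch.
by_shape: dict[tuple[int, int], list[tuple[int, int]]] = {}
for shape_hw in region_shapes:
    by_shape.setdefault(shape_hw, []).append(shape_hw)

# Chunk each group so no batch exceeds MAX_REGION_BATCH_SIZE.
batches = [
    group[i : i + MAX_REGION_BATCH_SIZE]
    for group in by_shape.values()
    for i in range(0, len(group), MAX_REGION_BATCH_SIZE)
]
print([len(b) for b in batches])  # [4, 1, 1]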
@@ -37,6 +73,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         init_timestep: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None],
     ) -> torch.Tensor:
+        self._check_regional_prompting(multi_diffusion_conditioning)
+
         # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
         # cases where denoising_start and denoising_end are set such that there are no timesteps.
         if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
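
The guess in the TODO is plausible: a denoising_start/denoising_end pair that selects an empty slice of the schedule leaves nothing to do, so the input latents should pass through unchanged. A hypothetical illustration (the slicing here is schematic, not InvokeAI's actual timestep logic):

import torch

timesteps = torch.arange(999, -1, -100)  # a 10-step schedule
denoising_start, denoising_end = 0.7, 0.7  # hypothetical values that collide

n = timesteps.shape[0]
selected = timesteps[int(denoising_start * n) : int(denoising_end * n)]
print(selected.shape[0])  # 0 -> the early return hands back `latents` untouched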
@@ -57,7 +95,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         self._adjust_memory_efficient_attention(latents)
 
         # Populate a weighted mask that will be used to combine the results from each region after every step.
-        # For now, we assume that each regions has the same weight (1.0).
+        # For now, we assume that each region has the same weight (1.0).
         region_weight_mask = torch.zeros(
             (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
         )
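
Since every region contributes weight 1.0, the mask simply counts how many regions cover each latent pixel; dividing the summed per-region results by it averages overlapping tiles. A small self-contained check (synthetic 8x8 latent grid, made-up regions):

import torch

latent_height = latent_width = 8
regions = [(0, 6, 0, 8), (2, 8, 0, 8)]  # (top, bottom, left, right); rows 2-5 overlap

region_weight_mask = torch.zeros((1, 1, latent_height, latent_width))
for top, bottom, left, right in regions:
    region_weight_mask[:, :, top:bottom, left:right] += 1.0

# Stand-in for the per-region predictions that get summed into merged_latents.
merged = torch.zeros((1, 1, latent_height, latent_width))
for top, bottom, left, right in regions:
    merged[:, :, top:bottom, left:right] += 1.0

# Overlapping rows were counted twice; normalizing restores a uniform result.
assert torch.all(merged / region_weight_mask == 1.0)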
@@ -65,11 +103,15 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             region = region_conditioning.region
             region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
 
+        # Group the region conditioning into batches for faster processing.
+        # region_conditioning_batches[b][r] is the r'th region in the b'th batch.
+        region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
+
         # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
         # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
         # separate scheduler state for each region batch.
         region_batch_schedulers: list[SchedulerMixin] = [
-            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
+            copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
         ]
 
         callback(
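
The deep copy per batch is needed because calling step() twice at the same timestep on one scheduler instance would advance its internal state twice. A toy stand-in (StatefulScheduler is hypothetical; real diffusers schedulers track state such as a step index or past samples):

import copy

class StatefulScheduler:
    # Minimal stand-in for a stateful diffusers SchedulerMixin.
    def __init__(self) -> None:
        self.steps_taken = 0

    def step(self) -> None:
        self.steps_taken += 1  # internal state advances on every call

base = StatefulScheduler()
region_batch_schedulers = [copy.deepcopy(base) for _ in range(3)]

# One step() per region batch at the same timestep t: each copy advances once,
# exactly as a single scheduler would in a non-tiled pipeline.
for scheduler in region_batch_schedulers:
    scheduler.step()
assert all(s.steps_taken == 1 for s in region_batch_schedulers)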
@@ -87,20 +129,52 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
             merged_latents = torch.zeros_like(latents)
             merged_pred_original: torch.Tensor | None = None
-            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+            for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
                 # Switch to the scheduler for the region batch.
-                self.scheduler = region_batch_schedulers[region_idx]
+                self.scheduler = region_batch_schedulers[region_batch_idx]
 
-                # Run a denoising step on the region.
-                step_output = self._region_step(
-                    region_conditioning=region_conditioning,
-                    t=batched_t,
-                    latents=latents,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
-                )
+                # TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
+
+                # Prepare the latents for the region batch.
+                batch_latents = torch.cat(
+                    [
+                        latents[
+                            :,
+                            :,
+                            region_conditioning.region.top : region_conditioning.region.bottom,
+                            region_conditioning.region.left : region_conditioning.region.right,
+                        ]
+                        for region_conditioning in region_conditioning_batch
+                    ],
+                )
+
+                # TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
+                # handle broadcasting properly?
+                # TODO(ryand): Resume here!
+
+                # Run the denoising step on the region batch.
+                step_output = self.step(
+                    t=batched_t,
+                    latents=batch_latents,
+                    conditioning_data=region_conditioning.text_conditioning_data,
+                    step_index=i,
+                    total_step_count=total_step_count,
+                    scheduler_step_kwargs=scheduler_step_kwargs,
+                    mask_guidance=None,
+                    mask=None,
+                    masked_latents=None,
+                    control_data=region_conditioning.control_data,
+                )
+
+                # Run a denoising step on the region.
+                # step_output = self._region_step(
+                #     region_conditioning=region_conditioning,
+                #     t=batched_t,
+                #     latents=latents,
+                #     step_index=i,
+                #     total_step_count=len(timesteps),
+                #     scheduler_step_kwargs=scheduler_step_kwargs,
+                # )
 
                 # Store the results from the region.
                 region = region_conditioning.region
                 merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
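
The batching pattern above, in isolation: crop each same-shaped region out of the full latents, concatenate along the batch dimension, and scatter the per-region results back after the step. A sketch with synthetic tensors (batch size 1, matching the untested-batch-size caveat in the TODO):

import torch

latents = torch.randn(1, 4, 16, 16)
regions = [(0, 8, 0, 16), (8, 16, 0, 16)]  # two 8x16 tiles: (top, bottom, left, right)

# Stack the crops along dim 0: one batch entry per region -> shape (2, 4, 8, 16).
batch_latents = torch.cat([latents[:, :, t:b, l:r] for t, b, l, r in regions])

# After denoising, each batch entry is written back to its source region.
merged = torch.zeros_like(latents)
for i, (t, b, l, r) in enumerate(regions):
    merged[:, :, t:b, l:r] += batch_latents[i : i + 1]

assert torch.equal(merged, latents)  # tiles don't overlap here, so it round-trips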
@@ -136,7 +210,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         return latents
 
     @torch.inference_mode()
-    def _region_step(
+    def _region_batch_step(
         self,
         region_conditioning: MultiDiffusionRegionConditioning,
         t: torch.Tensor,
@@ -145,13 +219,6 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         total_step_count: int,
         scheduler_step_kwargs: dict[str, Any],
     ):
-        use_regional_prompting = (
-            region_conditioning.text_conditioning_data.cond_regions is not None
-            or region_conditioning.text_conditioning_data.uncond_regions is not None
-        )
-        if use_regional_prompting:
-            raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
-
         # Crop the inputs to the region.
         region_latents = latents[
             :,