Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
WIP - Started working towards MultiDiffusion batching.

commit 6bcf48aa37
parent b1bb1511fe
@@ -15,6 +15,10 @@ from invokeai.backend.stable_diffusion.diffusers_pipeline import (
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningData
 from invokeai.backend.tiles.utils import TBLR
 
+# The maximum number of regions with compatible sizes that will be batched together.
+# Larger batch sizes improve speed, but require more device memory.
+MAX_REGION_BATCH_SIZE = 4
+
 
 @dataclass
 class MultiDiffusionRegionConditioning:
@@ -27,6 +31,38 @@ class MultiDiffusionRegionConditioning:
 class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
     """A Stable Diffusion pipeline that uses Multi-Diffusion (https://arxiv.org/pdf/2302.08113) for denoising."""
 
+    def _split_into_region_batches(
+        self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]
+    ) -> list[list[MultiDiffusionRegionConditioning]]:
+        # Group the regions by shape. Only regions with the same shape can be batched together.
+        conditioning_by_shape: dict[tuple[int, int], list[MultiDiffusionRegionConditioning]] = {}
+        for region_conditioning in multi_diffusion_conditioning:
+            shape_hw = (
+                region_conditioning.region.bottom - region_conditioning.region.top,
+                region_conditioning.region.right - region_conditioning.region.left,
+            )
+            # In python, a tuple of hashable objects is hashable, so can be used as a key in a dict.
+            if shape_hw not in conditioning_by_shape:
+                conditioning_by_shape[shape_hw] = []
+            conditioning_by_shape[shape_hw].append(region_conditioning)
+
+        # Split the regions into batches, respecting the MAX_REGION_BATCH_SIZE constraint.
+        region_conditioning_batches = []
+        for region_conditioning_batch in conditioning_by_shape.values():
+            for i in range(0, len(region_conditioning_batch), MAX_REGION_BATCH_SIZE):
+                region_conditioning_batches.append(region_conditioning_batch[i : i + MAX_REGION_BATCH_SIZE])
+
+        return region_conditioning_batches
+
+    def _check_regional_prompting(self, multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning]):
+        """Check the input conditioning and confirm that regional prompting is not used."""
+        for region_conditioning in multi_diffusion_conditioning:
+            if (
+                region_conditioning.text_conditioning_data.cond_regions is not None
+                or region_conditioning.text_conditioning_data.uncond_regions is not None
+            ):
+                raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
+
     def multi_diffusion_denoise(
         self,
         multi_diffusion_conditioning: list[MultiDiffusionRegionConditioning],
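
Reviewer note: a minimal, self-contained sketch of the batching behavior that _split_into_region_batches implements: regions are grouped by (height, width), and each group is chunked into lists of at most MAX_REGION_BATCH_SIZE. The Region namedtuple and the sample tile values below are illustrative stand-ins, not types from the InvokeAI codebase.

from collections import namedtuple

# Hypothetical stand-in for a TBLR-style region; only the four edges matter here.
Region = namedtuple("Region", ["top", "bottom", "left", "right"])

MAX_REGION_BATCH_SIZE = 4


def split_into_batches(regions: list[Region]) -> list[list[Region]]:
    # Group regions by (height, width); only same-shaped regions can share a batch.
    by_shape: dict[tuple[int, int], list[Region]] = {}
    for r in regions:
        by_shape.setdefault((r.bottom - r.top, r.right - r.left), []).append(r)

    # Chunk each shape group into batches of at most MAX_REGION_BATCH_SIZE regions.
    batches = []
    for group in by_shape.values():
        for i in range(0, len(group), MAX_REGION_BATCH_SIZE):
            batches.append(group[i : i + MAX_REGION_BATCH_SIZE])
    return batches


# Five 64x64 tiles plus one 64x32 edge tile -> batches of sizes [4, 1, 1].
tiles = [Region(0, 64, i * 48, i * 48 + 64) for i in range(5)] + [Region(0, 64, 240, 272)]
print([len(b) for b in split_into_batches(tiles)])  # [4, 1, 1]
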
@@ -37,6 +73,8 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         init_timestep: torch.Tensor,
         callback: Callable[[PipelineIntermediateState], None],
     ) -> torch.Tensor:
+        self._check_regional_prompting(multi_diffusion_conditioning)
+
         # TODO(ryand): Figure out why this condition is necessary, and document it. My guess is that it's to handle
         # cases where denoising_start and denoising_end are set such that there are no timesteps.
         if init_timestep.shape[0] == 0 or timesteps.shape[0] == 0:
@@ -57,7 +95,7 @@ class MultiDiffusionPipeline(StableDiffusionGeneratorPipeline):
         self._adjust_memory_efficient_attention(latents)
 
         # Populate a weighted mask that will be used to combine the results from each region after every step.
-        # For now, we assume that each regions has the same weight (1.0).
+        # For now, we assume that each region has the same weight (1.0).
         region_weight_mask = torch.zeros(
             (1, 1, latent_height, latent_width), device=latents.device, dtype=latents.dtype
         )
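
Reviewer note: region_weight_mask is what turns the per-region sums into an average where tiles overlap, which is the core of MultiDiffusion blending. A small numeric sketch of that arithmetic follows; the final division by the mask is assumed to happen in unchanged code outside this hunk (the comment above only says the mask "will be used to combine the results"), so treat it as an illustration of the idea rather than a quote of the pipeline.

import torch

# A 1x1x4x6 latent grid covered by two overlapping regions, each with weight 1.0.
weight_mask = torch.zeros(1, 1, 4, 6)
weight_mask[:, :, :, 0:4] += 1.0  # region A covers columns 0-3
weight_mask[:, :, :, 2:6] += 1.0  # region B covers columns 2-5; columns 2-3 now weigh 2.0

# Pretend each region predicted a constant value for its tile.
merged = torch.zeros_like(weight_mask)
merged[:, :, :, 0:4] += 1.0  # region A's prediction
merged[:, :, :, 2:6] += 3.0  # region B's prediction

# Dividing by the mask averages the overlap: (1 + 3) / 2 = 2 in columns 2-3.
averaged = merged / weight_mask
print(averaged[0, 0, 0])  # tensor([1., 1., 2., 2., 3., 3.])
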
@@ -65,11 +103,15 @@
             region = region_conditioning.region
             region_weight_mask[:, :, region.top : region.bottom, region.left : region.right] += 1.0
 
+        # Group the region conditioning into batches for faster processing.
+        # region_conditioning_batches[b][r] is the r'th region in the b'th batch.
+        region_conditioning_batches = self._split_into_region_batches(multi_diffusion_conditioning)
+
         # Many of the diffusers schedulers are stateful (i.e. they update internal state in each call to step()). Since
         # we are calling step() multiple times at the same timestep (once for each region batch), we must maintain a
         # separate scheduler state for each region batch.
         region_batch_schedulers: list[SchedulerMixin] = [
-            copy.deepcopy(self.scheduler) for _ in multi_diffusion_conditioning
+            copy.deepcopy(self.scheduler) for _ in region_conditioning_batches
         ]
 
         callback(
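
Reviewer note: the copy.deepcopy per region batch exists because diffusers schedulers mutate internal state on every step() call, so calling one shared scheduler several times at the same timestep (once per batch) would advance it too far. A toy illustration of the failure mode; ToyScheduler is a hypothetical stand-in, not a diffusers class.

import copy


class ToyScheduler:
    """Hypothetical stand-in for a stateful scheduler: it remembers which step it is on."""

    def __init__(self):
        self.step_index = 0

    def step(self, sample):
        self.step_index += 1  # internal state advances on every call
        return sample


region_batches = ["batch0", "batch1", "batch2"]
base = ToyScheduler()

# Sharing one scheduler: after a single timestep it already believes it is on step 3.
shared = base
for _ in region_batches:
    shared.step(sample=None)
print(shared.step_index)  # 3

# One copy per region batch (as in this diff): each copy advances by exactly one step.
per_batch = [copy.deepcopy(ToyScheduler()) for _ in region_batches]
for scheduler in per_batch:
    scheduler.step(sample=None)
print([s.step_index for s in per_batch])  # [1, 1, 1]
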
@@ -87,20 +129,52 @@
 
             merged_latents = torch.zeros_like(latents)
             merged_pred_original: torch.Tensor | None = None
-            for region_idx, region_conditioning in enumerate(multi_diffusion_conditioning):
+            for region_batch_idx, region_conditioning_batch in enumerate(region_conditioning_batches):
                 # Switch to the scheduler for the region batch.
-                self.scheduler = region_batch_schedulers[region_idx]
+                self.scheduler = region_batch_schedulers[region_batch_idx]
 
-                # Run a denoising step on the region.
-                step_output = self._region_step(
-                    region_conditioning=region_conditioning,
-                    t=batched_t,
-                    latents=latents,
-                    step_index=i,
-                    total_step_count=len(timesteps),
-                    scheduler_step_kwargs=scheduler_step_kwargs,
+                # TODO(ryand): This logic has not yet been tested with input latents with a batch_size > 1.
+
+                # Prepare the latents for the region batch.
+                batch_latents = torch.cat(
+                    [
+                        latents[
+                            :,
+                            :,
+                            region_conditioning.region.top : region_conditioning.region.bottom,
+                            region_conditioning.region.left : region_conditioning.region.right,
+                        ]
+                        for region_conditioning in region_conditioning_batch
+                    ],
                 )
 
+                # TODO(ryand): Do we have to repeat the text_conditioning_data to match the batch size? Or does step()
+                # handle broadcasting properly?
+
+                # TODO(ryand): Resume here!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+                # Run the denoising step on the region.
+                step_output = self.step(
+                    t=batched_t,
+                    latents=batch_latents,
+                    conditioning_data=region_conditioning.text_conditioning_data,
+                    step_index=i,
+                    total_step_count=total_step_count,
+                    scheduler_step_kwargs=scheduler_step_kwargs,
+                    mask_guidance=None,
+                    mask=None,
+                    masked_latents=None,
+                    control_data=region_conditioning.control_data,
+                )
+
+                # Run a denoising step on the region.
+                # step_output = self._region_step(
+                #     region_conditioning=region_conditioning,
+                #     t=batched_t,
+                #     latents=latents,
+                #     step_index=i,
+                #     total_step_count=len(timesteps),
+                #     scheduler_step_kwargs=scheduler_step_kwargs,
+                # )
+
                 # Store the results from the region.
                 region = region_conditioning.region
                 merged_latents[:, :, region.top : region.bottom, region.left : region.right] += step_output.prev_sample
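
Reviewer note: the work stops at the "Resume here" TODO. After this change, step_output.prev_sample carries one sample per region in the batch, but the code around the step() call still reads region_conditioning, which is presumably the leftover loop variable from the earlier weight-mask loop (Python comprehension variables do not leak), so the per-region scatter of results looks unfinished. Below is a hedged sketch of one way the batched output could be split back into per-region tiles: an assumed continuation built with toy tensors, not code from this commit.

import torch

# Toy setup: a 1x4x8x8 latent and two same-sized regions given as (top, bottom, left, right).
latents = torch.randn(1, 4, 8, 8)
regions = [(0, 4, 0, 4), (4, 8, 4, 8)]

# Batch the region crops along dim 0, as the diff's torch.cat([...]) does.
batch_latents = torch.cat([latents[:, :, t:b, l:r] for (t, b, l, r) in regions])

# Stand-in for the denoising step: here it simply returns its input unchanged.
prev_sample = batch_latents

# Scatter the per-region results back, pairing each region with its slice of the batch.
merged_latents = torch.zeros_like(latents)
for (t, b, l, r), region_prev in zip(regions, prev_sample.split(latents.shape[0], dim=0)):
    merged_latents[:, :, t:b, l:r] += region_prev

assert torch.equal(merged_latents[:, :, 0:4, 0:4], latents[:, :, 0:4, 0:4])
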
@@ -136,7 +210,7 @@
         return latents
 
     @torch.inference_mode()
-    def _region_step(
+    def _region_batch_step(
         self,
         region_conditioning: MultiDiffusionRegionConditioning,
         t: torch.Tensor,
@@ -145,13 +219,6 @@
         total_step_count: int,
         scheduler_step_kwargs: dict[str, Any],
     ):
-        use_regional_prompting = (
-            region_conditioning.text_conditioning_data.cond_regions is not None
-            or region_conditioning.text_conditioning_data.uncond_regions is not None
-        )
-        if use_regional_prompting:
-            raise NotImplementedError("Regional prompting is not yet supported in Multi-Diffusion.")
-
         # Crop the inputs to the region.
         region_latents = latents[
             :,