Implement (very slow) self-attention regional masking.

This commit is contained in:
Ryan Dick 2024-03-01 18:43:32 -05:00
parent ffc4ebb14c
commit 942efa011e
3 changed files with 195 additions and 60 deletions

View File

@ -92,15 +92,28 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# End unmodified block from AttnProcessor2_0.
# Handle regional prompt attention masks.
if is_cross_attention and regional_prompt_data is not None:
if regional_prompt_data is not None:
_, query_seq_len, _ = hidden_states.shape
prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
)
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
if is_cross_attention:
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
query_seq_len=query_seq_len, key_seq_len=sequence_length
)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=hidden_states.dtype, device=hidden_states.device
)
# prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
# prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
if attention_mask is None:
attention_mask = prompt_region_attention_mask

View File

@ -5,60 +5,67 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
TextConditioningRegions,
)
# Stages:
# - Convert image masks to spatial masks at all downsampling factors.
# - Decision: Max pooling? Nearest? Other?
# - Should definitely be shared across all denoising steps - that should be easy.
# - Convert spatial masks to cross-attention masks.
# - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len.
# - Could create it just-in-time and them cache the result
# - Convert spatial masks to self-attention masks.
# - This should be shared across all denoising steps.
# - Shape depends only on spatial resolution and downsampling factors.
# - Convert cross-attention binary mask to score mask.
# - Convert self-attention binary mask to score mask.
#
# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer.
# Pre-compute the spatial masks, because that's easy.
# Compute the other stuff as it's requested. Add caching if we find that it's slow.
class RegionalPromptData:
def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
self._attn_masks_by_seq_len = attn_masks_by_seq_len
@classmethod
def from_regions(
cls,
regions: list[TextConditioningRegions],
key_seq_len: int,
# TODO(ryand): Pass in a list of downscale factors?
max_downscale_factor: int = 8,
):
"""Construct a `RegionalPromptData` object.
def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
"""Initialize a `RegionalPromptData` object.
Args:
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
batch.
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
explicitly to be sure.
max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
"""
attn_masks_by_seq_len = {}
self._regions = regions
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
# sequence length of s.
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
regions, max_downscale_factor
)
def _prepare_spatial_masks(
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
) -> list[dict[int, torch.Tensor]]:
"""Prepare the spatial masks for all downscaling factors."""
# TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
# levels, but I need to double check that.
# batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
# of s.
batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
# batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# length of s.
batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_sample_regions in regions:
batch_attn_masks_by_seq_len.append({})
batch_sample_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
_, num_prompts, h, w = batch_masks.shape
b, num_prompts, h, w = batch_sample_masks.shape
assert b == 1
query_seq_len = h * w
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
# Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# `encoder_hidden_states`.
# attn_mask shape: (batch_size, query_seq_len, key_seq_len)
# TODO(ryand): What device / dtype should this be?
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
:, prompt_idx, :, :
]
batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
@ -66,23 +73,86 @@ class RegionalPromptData:
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# potentially use a weighted mask rather than a binary mask.
batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
return batch_sample_masks_by_seq_len
# Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
attn_masks_by_seq_len[query_seq_len] = torch.cat(
[batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
)
# for query_seq_len in batch_sample_masks_by_seq_len[0].keys():
# masks_by_seq_len[query_seq_len] = torch.cat(
# [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))]
# )
return cls(attn_masks_by_seq_len)
# return masks_by_seq_len
def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
"""Get the attention mask for the given query sequence length (i.e. downscaling level).
# @classmethod
# def from_regions(
# cls,
# regions: list[TextConditioningRegions],
# key_seq_len: int,
# max_downscale_factor: int = 8,
# ):
# """Construct a `RegionalPromptData` object.
This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
it changes at each downscaling level in the model.
# Args:
# regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
# batch.
# """
# attn_masks_by_seq_len = {}
key_seq_len is the length of the expected prompt embeddings.
# # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# # length of s.
# batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
# for batch_sample_regions in regions:
# batch_attn_masks_by_seq_len.append({})
# # Convert the bool masks to float masks so that max pooling can be applied.
# batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
# # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
# downscale_factor = 1
# while downscale_factor <= max_downscale_factor:
# _, num_prompts, h, w = batch_masks.shape
# query_seq_len = h * w
# # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
# batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
# # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# # `encoder_hidden_states`.
# # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
# # TODO(ryand): What device / dtype should this be?
# attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
# for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
# attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
# :, prompt_idx, :, :
# ]
# batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
# downscale_factor *= 2
# if downscale_factor <= max_downscale_factor:
# # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# # regions to be lost entirely.
# # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# # potentially use a weighted mask rather than a binary mask.
# batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
# # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
# for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
# attn_masks_by_seq_len[query_seq_len] = torch.cat(
# [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
# )
# return cls(attn_masks_by_seq_len)
def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
"""Get the cross-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
Returns:
torch.Tensor: The masks.
@ -90,4 +160,58 @@ class RegionalPromptData:
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
return self._attn_masks_by_seq_len[query_seq_len]
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
batch_sample_regions = self._regions[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_masks[
0, prompt_idx, :, :
]
return attn_mask
def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
"""Get the self-attention mask for the given query sequence length.
Args:
query_seq_len: The length of the flattened spatial features at the current downscaling level.
Returns:
torch.Tensor: The masks.
shape: (batch_size, query_seq_len, query_seq_len).
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
batch_size = len(self._spatial_masks_by_seq_len)
batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
# Create an empty attention mask with the correct shape.
attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
for batch_idx in range(batch_size):
batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
_, num_prompts, _, _ = batch_sample_spatial_masks.shape
batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
for prompt_idx in range(num_prompts):
prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0] # Shape: (query_seq_len,)
# Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
# query_seq_len) mask.
attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
# Since we were adding masks in the previous loop, we need to clamp the values to 1.0.
# attn_mask[attn_mask > 0.5] = 1.0
return attn_mask

View File

@ -346,9 +346,7 @@ class InvokeAIDiffuserComponent:
regions.append(r)
_, key_seq_len, _ = both_conditionings.shape
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
regions=regions, key_seq_len=key_seq_len
)
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
both_results = self.model_forward_callback(
x_twice,