Get RegionalPromptAttnProcessor2_0 working with a ton of hacks.

Ryan Dick 2024-02-17 19:56:37 -05:00
parent 2d5d370f38
commit d132fb4818


@@ -32,6 +32,58 @@ class RegionalPromptData:
class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
    """An attention processor that supports regional prompt attention for PyTorch 2.0."""

    def _prepare_regional_prompt_attention_mask(
        self,
        regional_prompt_data: RegionalPromptData,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        orig_attn_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Infer the current spatial dimensions from the shape of `hidden_states`.
        _, query_seq_len, _ = hidden_states.shape
        per_prompt_query_masks = regional_prompt_data.masks
        _, _, h, w = per_prompt_query_masks.shape

        # Downsample by factors of 2 until the spatial dimensions match the current query sequence length.
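        # For example (illustrative numbers): a 64x64 mask with query_seq_len = 1024 (a 32x32 feature map) gives
        # scale_factor = 2.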
        scale_factor = 1
        while h * w > query_seq_len:
            scale_factor *= 2
            h //= 2
            w //= 2
        assert h * w == query_seq_len
        # Apply max-pooling to resize the masks to the target spatial dimensions.
        # TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
        # here.
        per_prompt_query_masks = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
        batch_size, num_prompts, resized_h, resized_w = per_prompt_query_masks.shape
        assert resized_h == h and resized_w == w
        # Flatten the spatial dimensions of the masks.
        # Shape after reshape: (batch_size, num_prompts, query_seq_len, 1)
        per_prompt_query_masks = per_prompt_query_masks.reshape((batch_size, num_prompts, -1, 1))
        # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
        # `encoder_hidden_states`.
        # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
        _, key_seq_len, _ = encoder_hidden_states.shape

        # HACK(ryand): We are assuming the batch size.
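        # The hard-coded 2 below corresponds to a classifier-free-guidance batch of [unconditioned, conditioned]
        # (see the HACK notes below).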
        attn_mask = torch.zeros((2, query_seq_len, key_seq_len), device=hidden_states.device)
        for i, embedding_range in enumerate(regional_prompt_data.embedding_ranges):
            # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. This is too fragile
            # to merge.
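            # Note: the assignment below only broadcasts correctly when per_prompt_query_masks has batch_size == 1,
            # since a (1, query_seq_len, 1) tensor is broadcast over the (query_seq_len, range_len) slice.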
            attn_mask[1, :, embedding_range.start : embedding_range.end] = per_prompt_query_masks[:, i, :, :]
        # HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. We are also assuming
        # the intent of attn_mask. And we shouldn't have to do this awkward mask type conversion.
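        # The conversion below presumably treats `orig_attn_mask` as an additive mask (~0.0 where attention is
        # allowed, large negative where masked) and turns it into a 0/1 mask.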
        orig_mask = torch.zeros_like(orig_attn_mask[0, ...])
        orig_mask[orig_attn_mask[0, ...] > -0.5] = 1.0
        attn_mask[0, ...] = orig_mask

        return attn_mask > 0.5

    def __call__(
        self,
        attn: Attention,
@@ -56,6 +108,13 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )

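        # For cross-attention (i.e. when `encoder_hidden_states` is provided), replace the incoming attention mask
        # with the regional prompt cross-attention mask; self-attention is left untouched.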
        if encoder_hidden_states is not None:
            assert regional_prompt_data is not None
            assert attention_mask is not None
            attention_mask = self._prepare_regional_prompt_attention_mask(
                regional_prompt_data, hidden_states, encoder_hidden_states, attention_mask
            )

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
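
For reference, below is a minimal standalone sketch of the masking scheme implemented above (not part of the commit). The shapes, the two-prompt setup, and the EmbeddingRange stand-in are illustrative assumptions; it only shows the core idea of max-pooling a per-prompt region mask down to the query resolution, flattening it, and scattering it into a boolean (query_seq_len, key_seq_len) cross-attention mask over each prompt's embedding range.

from dataclasses import dataclass

import torch
import torch.nn.functional as F


@dataclass
class EmbeddingRange:  # hypothetical stand-in for the ranges carried by RegionalPromptData
    start: int
    end: int


# Two regional prompts, each occupying 77 tokens of the concatenated text embedding.
embedding_ranges = [EmbeddingRange(0, 77), EmbeddingRange(77, 154)]
key_seq_len = 154

# Per-prompt region masks at latent resolution: (batch=1, num_prompts=2, 64, 64).
masks = torch.zeros((1, 2, 64, 64))
masks[:, 0, :, :32] = 1.0  # prompt 0 controls the left half of the image
masks[:, 1, :, 32:] = 1.0  # prompt 1 controls the right half of the image

# Suppose this attention layer operates on a 32x32 feature map, i.e. query_seq_len = 1024.
query_seq_len = 1024
h, w, scale_factor = 64, 64, 1
while h * w > query_seq_len:
    scale_factor *= 2
    h //= 2
    w //= 2
masks = F.max_pool2d(masks, kernel_size=scale_factor, stride=scale_factor)  # (1, 2, 32, 32)
masks = masks.reshape((1, 2, -1))  # (1, 2, 1024)

# Build a boolean cross-attention mask: each query position may attend only to the
# text tokens of the prompts whose region covers it.
attn_mask = torch.zeros((1, query_seq_len, key_seq_len), dtype=torch.bool)
for i, r in enumerate(embedding_ranges):
    attn_mask[:, :, r.start : r.end] = masks[:, i, :].unsqueeze(-1) > 0.5

print(attn_mask.shape)  # torch.Size([1, 1024, 154])

In the commit itself, batch index 0 of the mask is reserved for the unconditioned pass and batch index 1 carries this regional mask, per the HACK notes above.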