Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Implement (very slow) self-attention regional masking.
commit 942efa011e
parent ffc4ebb14c
@@ -92,15 +92,28 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         # End unmodified block from AttnProcessor2_0.
 
         # Handle regional prompt attention masks.
-        if is_cross_attention and regional_prompt_data is not None:
+        if regional_prompt_data is not None:
             _, query_seq_len, _ = hidden_states.shape
-            prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
-            # TODO(ryand): Avoid redundant type/device conversion here.
-            prompt_region_attention_mask = prompt_region_attention_mask.to(
-                dtype=encoder_hidden_states.dtype, device=encoder_hidden_states.device
-            )
-            prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
-            prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
+            if is_cross_attention:
+                prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
+                    query_seq_len=query_seq_len, key_seq_len=sequence_length
+                )
+                # TODO(ryand): Avoid redundant type/device conversion here.
+                prompt_region_attention_mask = prompt_region_attention_mask.to(
+                    dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
+                prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
+
+            else:  # self-attention
+                prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(query_seq_len=query_seq_len)
+
+                # TODO(ryand): Avoid redundant type/device conversion here.
+                prompt_region_attention_mask = prompt_region_attention_mask.to(
+                    dtype=hidden_states.dtype, device=hidden_states.device
+                )
+                # prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -0.5
+                # prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
 
             if attention_mask is None:
                 attention_mask = prompt_region_attention_mask
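For readers who want to see the masking convention above in isolation: AttnProcessor2_0 (which CustomAttnProcessor2_0 extends) runs attention through torch.nn.functional.scaled_dot_product_attention, where a floating-point attn_mask is added to the raw attention scores before the softmax. The sketch below is not InvokeAI code; the region layout and all shapes are made up purely to show how the -10000.0 / 0.0 conversion suppresses disallowed query/key pairs.

import torch
import torch.nn.functional as F

# Toy sizes: 1 batch sample, 4 query positions (a flattened 2x2 latent), 6 key tokens, 8 channels.
batch_size, query_seq_len, key_seq_len, dim = 1, 4, 6, 8
query = torch.randn(batch_size, query_seq_len, dim)
key = torch.randn(batch_size, key_seq_len, dim)
value = torch.randn(batch_size, key_seq_len, dim)

# Hypothetical binary region mask: queries 0-1 may attend only to tokens 0-2,
# queries 2-3 only to tokens 3-5.
binary_mask = torch.zeros(batch_size, query_seq_len, key_seq_len)
binary_mask[0, :2, :3] = 1.0
binary_mask[0, 2:, 3:] = 1.0

# Same conversion as in the diff: allowed pairs get an additive bias of 0.0,
# disallowed pairs get -10000.0 so their post-softmax weight is effectively zero.
score_mask = binary_mask.clone()
score_mask[binary_mask < 0.5] = -10000.0
score_mask[binary_mask >= 0.5] = 0.0

out = F.scaled_dot_product_attention(query, key, value, attn_mask=score_mask)
print(out.shape)  # torch.Size([1, 4, 8])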
@@ -5,60 +5,67 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     TextConditioningRegions,
 )
 
+# Stages:
+# - Convert image masks to spatial masks at all downsampling factors.
+#   - Decision: Max pooling? Nearest? Other?
+#   - Should definitely be shared across all denoising steps - that should be easy.
+# - Convert spatial masks to cross-attention masks.
+#   - This should ideally be shared across all denoising steps, but preparing the masks requires knowing the max_key_seq_len.
+#   - Could create it just-in-time and them cache the result
+# - Convert spatial masks to self-attention masks.
+#   - This should be shared across all denoising steps.
+#   - Shape depends only on spatial resolution and downsampling factors.
+# - Convert cross-attention binary mask to score mask.
+# - Convert self-attention binary mask to score mask.
+#
+# If we wanted a time schedule for level of attenuation, we would apply that in the attention layer.
+
+
+# Pre-compute the spatial masks, because that's easy.
+# Compute the other stuff as it's requested. Add caching if we find that it's slow.
 
 
 class RegionalPromptData:
-    def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
-        self._attn_masks_by_seq_len = attn_masks_by_seq_len
-
-    @classmethod
-    def from_regions(
-        cls,
-        regions: list[TextConditioningRegions],
-        key_seq_len: int,
-        # TODO(ryand): Pass in a list of downscale factors?
-        max_downscale_factor: int = 8,
-    ):
-        """Construct a `RegionalPromptData` object.
+    def __init__(self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8):
+        """Initialize a `RegionalPromptData` object.
 
         Args:
             regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
                 batch.
-            key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
-                cross-attention layers). This is most likely equal to the max embedding range end, but we pass it
-                explicitly to be sure.
+            max_downscale_factor: The maximum downscale factor to use when preparing the spatial masks.
         """
-        attn_masks_by_seq_len = {}
+        self._regions = regions
+        # self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
+        # sequence length of s.
+        self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
+            regions, max_downscale_factor
+        )
+
+    def _prepare_spatial_masks(
+        self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
+    ) -> list[dict[int, torch.Tensor]]:
+        """Prepare the spatial masks for all downscaling factors."""
+        # TODO(ryand): Pass in a list of downscale factors? IIRC, SDXL does not apply attention at all downscaling
+        # levels, but I need to double check that.
 
-        # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
-        # length of s.
-        batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
+        # batch_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query sequence length
+        # of s.
+        batch_sample_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
 
         for batch_sample_regions in regions:
-            batch_attn_masks_by_seq_len.append({})
+            batch_sample_masks_by_seq_len.append({})
 
             # Convert the bool masks to float masks so that max pooling can be applied.
-            batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
+            batch_sample_masks = batch_sample_regions.masks.to(dtype=torch.float32)
 
             # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
             downscale_factor = 1
             while downscale_factor <= max_downscale_factor:
-                _, num_prompts, h, w = batch_masks.shape
+                b, num_prompts, h, w = batch_sample_masks.shape
+                assert b == 1
                 query_seq_len = h * w
 
-                # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
-                batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
-
-                # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
-                # `encoder_hidden_states`.
-                # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
-                # TODO(ryand): What device / dtype should this be?
-                attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
-
-                for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
-                    attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
-                        :, prompt_idx, :, :
-                    ]
-
-                batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
+                batch_sample_masks_by_seq_len[-1][query_seq_len] = batch_sample_masks
 
                 downscale_factor *= 2
                 if downscale_factor <= max_downscale_factor:
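To make the shape bookkeeping in _prepare_spatial_masks concrete, here is a self-contained sketch under simplified assumptions (a plain tensor instead of TextConditioningRegions, a single batch sample, and none of the InvokeAI plumbing). It runs the same downsample-by-max-pooling loop and keys each level by its flattened spatial size, which is the query sequence length the attention layers will see at that level.

import torch
import torch.nn.functional as F


def prepare_spatial_masks_sketch(masks: torch.Tensor, max_downscale_factor: int = 8) -> dict[int, torch.Tensor]:
    """Toy version: `masks` has shape (1, num_prompts, h, w) and dtype bool."""
    masks = masks.to(dtype=torch.float32)
    masks_by_seq_len: dict[int, torch.Tensor] = {}
    downscale_factor = 1
    while downscale_factor <= max_downscale_factor:
        _, _, h, w = masks.shape
        # The flattened spatial size at this level is the query sequence length.
        masks_by_seq_len[h * w] = masks
        downscale_factor *= 2
        if downscale_factor <= max_downscale_factor:
            # Max pooling keeps small regions from disappearing at low resolution.
            masks = F.max_pool2d(masks, kernel_size=2, stride=2)
    return masks_by_seq_len


# A 64x64 latent with two prompt regions: left half and right half.
masks = torch.zeros((1, 2, 64, 64), dtype=torch.bool)
masks[0, 0, :, :32] = True
masks[0, 1, :, 32:] = True
print(sorted(prepare_spatial_masks_sketch(masks).keys()))  # [64, 256, 1024, 4096]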
@@ -66,23 +73,86 @@ class RegionalPromptData:
                     # regions to be lost entirely.
                     # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
                     # potentially use a weighted mask rather than a binary mask.
-                    batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
+                    batch_sample_masks = F.max_pool2d(batch_sample_masks, kernel_size=2, stride=2)
 
-        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
-        for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
-            attn_masks_by_seq_len[query_seq_len] = torch.cat(
-                [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
-            )
-
-        return cls(attn_masks_by_seq_len)
-
-    def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
-        """Get the attention mask for the given query sequence length (i.e. downscaling level).
-
-        This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
-        it changes at each downscaling level in the model.
-
-        key_seq_len is the length of the expected prompt embeddings.
+        return batch_sample_masks_by_seq_len
+        # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
+        # for query_seq_len in batch_sample_masks_by_seq_len[0].keys():
+        #     masks_by_seq_len[query_seq_len] = torch.cat(
+        #         [batch_sample_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_sample_masks_by_seq_len))]
+        #     )
+
+        # return masks_by_seq_len
+
+    # @classmethod
+    # def from_regions(
+    #     cls,
+    #     regions: list[TextConditioningRegions],
+    #     key_seq_len: int,
+    #     max_downscale_factor: int = 8,
+    # ):
+    #     """Construct a `RegionalPromptData` object.
+
+    #     Args:
+    #         regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
+    #             batch.
+    #     """
+    #     attn_masks_by_seq_len = {}
+
+    #     # batch_attn_mask_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
+    #     # length of s.
+    #     batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
+    #     for batch_sample_regions in regions:
+    #         batch_attn_masks_by_seq_len.append({})
+
+    #         # Convert the bool masks to float masks so that max pooling can be applied.
+    #         batch_masks = batch_sample_regions.masks.to(dtype=torch.float32)
+
+    #         # Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
+    #         downscale_factor = 1
+    #         while downscale_factor <= max_downscale_factor:
+    #             _, num_prompts, h, w = batch_masks.shape
+    #             query_seq_len = h * w
+
+    #             # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
+    #             batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
+
+    #             # Create a cross-attention mask for each prompt that selects the corresponding embeddings from
+    #             # `encoder_hidden_states`.
+    #             # attn_mask shape: (batch_size, query_seq_len, key_seq_len)
+    #             # TODO(ryand): What device / dtype should this be?
+    #             attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
+
+    #             for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
+    #                 attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
+    #                     :, prompt_idx, :, :
+    #                 ]
+
+    #             batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
+
+    #             downscale_factor *= 2
+    #             if downscale_factor <= max_downscale_factor:
+    #                 # We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
+    #                 # regions to be lost entirely.
+    #                 # TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
+    #                 # potentially use a weighted mask rather than a binary mask.
+    #                 batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
+
+    #     # Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
+    #     for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
+    #         attn_masks_by_seq_len[query_seq_len] = torch.cat(
+    #             [batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
+    #         )
+
+    #     return cls(attn_masks_by_seq_len)
+
+    def get_cross_attn_mask(self, query_seq_len: int, key_seq_len: int) -> torch.Tensor:
+        """Get the cross-attention mask for the given query sequence length.
+
+        Args:
+            query_seq_len: The length of the flattened spatial features at the current downscaling level.
+            key_seq_len (int): The sequence length of the prompt embeddings (which act as the key in the cross-attention
+                layers). This is most likely equal to the max embedding range end, but we pass it explicitly to be sure.
 
         Returns:
             torch.Tensor: The masks.
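The body of the new get_cross_attn_mask appears in the next hunk. Conceptually, for each prompt it copies that prompt's flattened spatial mask into the key columns spanned by the prompt's embedding range, so each latent position ends up attending only to the text tokens of the prompt that owns it. The scatter step in isolation, with made-up embedding ranges standing in for TextConditioningRegions.ranges:

import torch

# Toy setup: 4 query positions, 10 key tokens, and two prompts whose embeddings
# occupy token ranges [0, 5) and [5, 10).
query_seq_len, key_seq_len = 4, 10
ranges = [(0, 5), (5, 10)]

# Flattened spatial masks, shape (num_prompts, query_seq_len): prompt 0 owns the
# first two latent positions, prompt 1 the last two.
query_masks = torch.tensor(
    [
        [1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 1.0],
    ]
)

attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, (start, end) in enumerate(ranges):
    # Broadcast the (query_seq_len,) spatial mask across the prompt's key columns.
    attn_mask[0, :, start:end] = query_masks[prompt_idx].unsqueeze(-1)

print(attn_mask[0])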
@@ -90,4 +160,58 @@ class RegionalPromptData:
                 dtype: float
            The mask is a binary mask with values of 0.0 and 1.0.
        """
-        return self._attn_masks_by_seq_len[query_seq_len]
+        batch_size = len(self._spatial_masks_by_seq_len)
+        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
+
+        # Create an empty attention mask with the correct shape.
+        attn_mask = torch.zeros((batch_size, query_seq_len, key_seq_len))
+
+        for batch_idx in range(batch_size):
+            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
+            batch_sample_regions = self._regions[batch_idx]
+
+            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
+            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
+            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
+
+            for prompt_idx, embedding_range in enumerate(batch_sample_regions.ranges):
+                attn_mask[batch_idx, :, embedding_range.start : embedding_range.end] = batch_sample_query_masks[
+                    0, prompt_idx, :, :
+                ]
+
+        return attn_mask
+
+    def get_self_attn_mask(self, query_seq_len: int) -> torch.Tensor:
+        """Get the self-attention mask for the given query sequence length.
+
+        Args:
+            query_seq_len: The length of the flattened spatial features at the current downscaling level.
+
+        Returns:
+            torch.Tensor: The masks.
+                shape: (batch_size, query_seq_len, query_seq_len).
+                dtype: float
+            The mask is a binary mask with values of 0.0 and 1.0.
+        """
+        batch_size = len(self._spatial_masks_by_seq_len)
+        batch_spatial_masks = [self._spatial_masks_by_seq_len[b][query_seq_len] for b in range(batch_size)]
+
+        # Create an empty attention mask with the correct shape.
+        attn_mask = torch.zeros((batch_size, query_seq_len, query_seq_len))
+
+        for batch_idx in range(batch_size):
+            batch_sample_spatial_masks = batch_spatial_masks[batch_idx]
+
+            # Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
+            _, num_prompts, _, _ = batch_sample_spatial_masks.shape
+            batch_sample_query_masks = batch_sample_spatial_masks.view((1, num_prompts, query_seq_len, 1))
+
+            for prompt_idx in range(num_prompts):
+                prompt_query_mask = batch_sample_query_masks[0, prompt_idx, :, 0]  # Shape: (query_seq_len,)
+                # Multiply a (1, query_seq_len) mask by a (query_seq_len, 1) mask to get a (query_seq_len,
+                # query_seq_len) mask.
+                attn_mask[batch_idx, :, :] += prompt_query_mask.unsqueeze(0) * prompt_query_mask.unsqueeze(1) * 0.5
+
+        # Since we were adding masks in the previous loop, we need to clamp the values to 1.0.
+        # attn_mask[attn_mask > 0.5] = 1.0
+        return attn_mask
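get_self_attn_mask is the genuinely new piece. For each batch sample it accumulates, per prompt, a 0.5-weighted outer product of the prompt's flattened spatial mask with itself, so two latent positions receive a positive bias only when they fall in the same region. The result has shape (batch_size, query_seq_len, query_seq_len); at ordinary SD latent sizes (e.g. query_seq_len = 4096 for a 64x64 latent) that is a dense quadratic tensor rebuilt in a Python loop, which is presumably what the "(very slow)" in the commit message refers to. A toy sketch of the outer-product accumulation:

import torch

query_seq_len = 4
# Two prompt regions over a flattened 2x2 latent: positions {0, 1} and {2, 3}.
prompt_query_masks = torch.tensor(
    [
        [1.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 1.0],
    ]
)

attn_mask = torch.zeros((1, query_seq_len, query_seq_len))
for prompt_query_mask in prompt_query_masks:
    # (query_seq_len, 1) * (1, query_seq_len) -> (query_seq_len, query_seq_len) outer product.
    attn_mask[0] += prompt_query_mask.unsqueeze(1) * prompt_query_mask.unsqueeze(0) * 0.5

print(attn_mask[0])
# tensor([[0.5000, 0.5000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000, 0.0000],
#         [0.0000, 0.0000, 0.5000, 0.5000],
#         [0.0000, 0.0000, 0.5000, 0.5000]])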
@@ -346,9 +346,7 @@ class InvokeAIDiffuserComponent:
                 regions.append(r)
 
             _, key_seq_len, _ = both_conditionings.shape
-            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData.from_regions(
-                regions=regions, key_seq_len=key_seq_len
-            )
+            cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(regions=regions)
 
         both_results = self.model_forward_callback(
             x_twice,