Handle conditioned and unconditioned text conditioning in the same way for regional prompt attention.

Ryan Dick 2024-02-27 18:16:01 -05:00
parent b0fcbe552e
commit 2966c8de2c
3 changed files with 275 additions and 87 deletions

View File

@@ -65,6 +65,7 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
# TODO(ryand): Support masks for unconditioned_embeddings.
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: list[BasicConditioningInfo]
text_embedding_masks: list[Optional[torch.Tensor]]
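
A minimal sketch of the invariant behind the new field, using a hypothetical stand-in dataclass rather than the repo's `ConditioningData`: the mask list is parallel to the prompt-embedding list, with `None` meaning the prompt applies to the whole image.

```python
# Hypothetical stand-in for ConditioningData, illustrating the parallel lists only.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class _ConditioningDataSketch:
    unconditioned_embeddings: torch.Tensor
    text_embeddings: list[torch.Tensor]
    text_embedding_masks: list[Optional[torch.Tensor]]


cond = _ConditioningDataSketch(
    unconditioned_embeddings=torch.zeros(1, 77, 768),
    text_embeddings=[torch.zeros(1, 77, 768), torch.zeros(1, 77, 768)],
    # One entry per prompt: None = global prompt, a bool mask = regional prompt.
    text_embedding_masks=[None, torch.ones(1, 64, 64, dtype=torch.bool)],
)
assert len(cond.text_embeddings) == len(cond.text_embedding_masks)
```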

View File

@@ -15,78 +15,102 @@ class Range:
end: int
@dataclass
class RegionalPromptData:
# The region masks for each prompt.
# shape: (batch_size, num_prompts, height, width)
# dtype: float*
# The mask is set to 1.0 in regions where the prompt should be applied, and 0.0 elsewhere.
masks: torch.Tensor
def __init__(self, attn_masks_by_seq_len: dict[int, torch.Tensor]):
self._attn_masks_by_seq_len = attn_masks_by_seq_len
# The embedding ranges for each prompt.
# The i'th mask is applied to the embeddings in:
# encoder_hidden_states[:, embedding_ranges[i].start:embedding_ranges[i].end, :]
embedding_ranges: list[Range]
@classmethod
def from_masks_and_ranges(
cls,
masks: list[torch.Tensor],
embedding_ranges: list[list[Range]],
key_seq_len: int,
# TODO(ryand): Pass in a list of downscale factors?
max_downscale_factor: int = 8,
):
"""Construct a `RegionalPromptData` object.
Args:
masks (list[torch.Tensor]): masks[i] contains the regions masks for the i'th sample in the batch.
The shape of masks[i] is (num_prompts, height, width), and dtype=bool. The mask is set to True in
regions where the prompt should be applied, and False elsewhere.
embedding_ranges (list[list[Range]]): embedding_ranges[i][j] contains the embedding range for the j'th
prompt in the i'th batch sample. masks[i][j, ...] is applied to the embeddings in:
encoder_hidden_states[i, embedding_ranges[i][j].start:embedding_ranges[i][j].end, :].
key_seq_len (int): The sequence length of the expected prompt embeddings (which act as the key in the
cross-attention layers).
"""
attn_masks_by_seq_len = {}
# batch_attn_masks_by_seq_len[b][s] contains the attention mask for the b'th batch sample with a query sequence
# length of s.
batch_attn_masks_by_seq_len: list[dict[int, torch.Tensor]] = []
for batch_masks, batch_ranges in zip(masks, embedding_ranges, strict=True):
batch_attn_masks_by_seq_len.append({})
# Convert the bool masks to float masks so that max pooling can be applied.
batch_masks = batch_masks.to(dtype=torch.float32)
# Downsample the spatial dimensions by factors of 2 until max_downscale_factor is reached.
downscale_factor = 1
while downscale_factor <= max_downscale_factor:
_, num_prompts, h, w = batch_masks.shape
query_seq_len = h * w
# Flatten the spatial dimensions of the mask by reshaping to (1, num_prompts, query_seq_len, 1).
batch_query_masks = batch_masks.reshape((1, num_prompts, -1, 1))
# Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# `encoder_hidden_states`.
# attn_mask shape: (batch_size, query_seq_len, key_seq_len)
# TODO(ryand): What device / dtype should this be?
attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
for prompt_idx, embedding_range in enumerate(batch_ranges):
attn_mask[0, :, embedding_range.start : embedding_range.end] = batch_query_masks[
:, prompt_idx, :, :
]
batch_attn_masks_by_seq_len[-1][query_seq_len] = attn_mask
downscale_factor *= 2
if downscale_factor <= max_downscale_factor:
# We use max pooling because we downscale to a pretty low resolution, so we don't want small prompt
# regions to be lost entirely.
# TODO(ryand): In the future, we may want to experiment with other downsampling methods, and could
# potentially use a weighted mask rather than a binary mask.
batch_masks = F.max_pool2d(batch_masks, kernel_size=2, stride=2)
# Merge the batch_attn_masks_by_seq_len into a single attn_masks_by_seq_len.
for query_seq_len in batch_attn_masks_by_seq_len[0].keys():
attn_masks_by_seq_len[query_seq_len] = torch.cat(
[batch_attn_masks_by_seq_len[i][query_seq_len] for i in range(len(batch_attn_masks_by_seq_len))]
)
return cls(attn_masks_by_seq_len)
def get_attn_mask(self, query_seq_len: int) -> torch.Tensor:
"""Get the attention mask for the given query sequence length (i.e. downscaling level).
This is called during cross-attention, where query_seq_len is the length of the flattened spatial features, so
it changes at each downscaling level in the model.
key_seq_len is the length of the expected prompt embeddings.
Returns:
torch.Tensor: The masks.
shape: (batch_size, query_seq_len, key_seq_len).
dtype: float
The mask is a binary mask with values of 0.0 and 1.0.
"""
return self._attn_masks_by_seq_len[query_seq_len]
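
For reference, here is a standalone sketch of what `from_masks_and_ranges` precomputes for a single batch sample: one binary cross-attention mask per query resolution, keyed by the flattened spatial length that `get_attn_mask` later receives. The function name and tuple-based ranges are illustrative, not the repo's API.

```python
import torch
import torch.nn.functional as F


def build_attn_masks_by_seq_len(
    masks: torch.Tensor,            # (num_prompts, h, w), bool: True where the prompt applies
    ranges: list[tuple[int, int]],  # per-prompt (start, end) token range in the concatenated embedding
    key_seq_len: int,
    max_downscale_factor: int = 8,
) -> dict[int, torch.Tensor]:
    masks = masks.to(torch.float32)  # float so max pooling can be applied
    out: dict[int, torch.Tensor] = {}
    downscale_factor = 1
    while downscale_factor <= max_downscale_factor:
        num_prompts, h, w = masks.shape
        query_seq_len = h * w
        attn_mask = torch.zeros((1, query_seq_len, key_seq_len))
        flat = masks.reshape(num_prompts, -1)  # flatten spatial dims: (num_prompts, query_seq_len)
        for i, (start, end) in enumerate(ranges):
            # Queries inside prompt i's region may attend to its token range.
            attn_mask[0, :, start:end] = flat[i].unsqueeze(-1)
        out[query_seq_len] = attn_mask
        downscale_factor *= 2
        if downscale_factor <= max_downscale_factor:
            # Max-pool so small regions are not lost at coarse resolutions.
            masks = F.max_pool2d(masks.unsqueeze(0), kernel_size=2, stride=2).squeeze(0)
    return out


masks_by_len = build_attn_masks_by_seq_len(
    masks=torch.ones(2, 64, 64, dtype=torch.bool),
    ranges=[(0, 77), (77, 154)],
    key_seq_len=154,
)
assert sorted(masks_by_len.keys()) == [64, 256, 1024, 4096]  # one mask per downscale level
```

Precomputing the masks up front is what removes the per-step mask construction that the old `_prepare_regional_prompt_attention_mask` (below) performed inside the attention processor.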
class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
"""An attention processor that supports regional prompt attention for PyTorch 2.0."""
def _prepare_regional_prompt_attention_mask(
self,
regional_prompt_data: RegionalPromptData,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
orig_attn_mask: torch.Tensor,
) -> torch.Tensor:
# Infer the current spatial dimensions from the shape of `hidden_states`.
_, query_seq_len, _ = hidden_states.shape
per_prompt_query_masks = regional_prompt_data.masks
_, _, h, w = per_prompt_query_masks.shape
# Downsample by factors of 2 until the spatial dimensions match the current query sequence length.
scale_factor = 1
while h * w > query_seq_len:
scale_factor *= 2
h //= 2
w //= 2
assert h * w == query_seq_len
# Convert the bool masks to float masks.
per_prompt_query_masks = per_prompt_query_masks.to(dtype=torch.float32)
# Apply max-pooling to resize the masks to the target spatial dimensions.
# TODO(ryand): We should be able to pre-compute all of the mask sizes. There's a lot of redundant computation
# here.
per_prompt_query_masks = F.max_pool2d(per_prompt_query_masks, kernel_size=scale_factor, stride=scale_factor)
batch_size, num_prompts, resized_h, resized_w = per_prompt_query_masks.shape
assert resized_h == h and resized_w == w
# Flatten the spatial dimensions of the masks.
# Shape after reshape: (batch_size, num_prompts, query_seq_len, 1)
per_prompt_query_masks = per_prompt_query_masks.reshape((batch_size, num_prompts, -1, 1))
# Create a cross-attention mask for each prompt that selects the corresponding embeddings from
# `encoder_hidden_states`.
# attn_mask shape: (batch_size, query_seq_len, key_seq_len)
_, key_seq_len, _ = encoder_hidden_states.shape
# HACK(ryand): We are assuming the batch size.
attn_mask = torch.zeros((2, query_seq_len, key_seq_len), device=hidden_states.device)
for i, embedding_range in enumerate(regional_prompt_data.embedding_ranges):
# HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. This is too fragile
# to merge.
attn_mask[1, :, embedding_range.start : embedding_range.end] = per_prompt_query_masks[:, i, :, :]
# HACK(ryand): We are assuming that batch 0 is unconditioned and batch 1 is conditioned. We are also assuming
# the intent of attn_mask. And we shouldn't have to do this awkward mask type conversion.
orig_mask = torch.zeros_like(orig_attn_mask[0, ...])
orig_mask[orig_attn_mask[0, ...] > -0.5] = 1.0
attn_mask[0, ...] = orig_mask
return attn_mask > 0.5
def __call__(
self,
attn: Attention,
@@ -114,9 +138,16 @@ class RegionalPromptAttnProcessor2_0(AttnProcessor2_0):
if encoder_hidden_states is not None:
assert regional_prompt_data is not None
assert attention_mask is not None
attention_mask = self._prepare_regional_prompt_attention_mask(
regional_prompt_data, hidden_states, encoder_hidden_states, attention_mask
_, query_seq_len, _ = hidden_states.shape
prompt_region_attention_mask = regional_prompt_data.get_attn_mask(query_seq_len)
# TODO(ryand): Avoid redundant type/device conversion here.
prompt_region_attention_mask = prompt_region_attention_mask.to(
dtype=attention_mask.dtype, device=attention_mask.device
)
prompt_region_attention_mask[prompt_region_attention_mask < 0.5] = -10000.0
prompt_region_attention_mask[prompt_region_attention_mask >= 0.5] = 0.0
attention_mask = prompt_region_attention_mask + attention_mask
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
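
A sketch of the mask combination done in `__call__` above: the precomputed regional mask is binary (1.0 = may attend), so it is converted to an additive bias and summed with the existing padding mask before `attn.prepare_attention_mask` runs. Function and variable names are illustrative.

```python
import torch


def combine_masks(region_mask: torch.Tensor, padding_mask: torch.Tensor) -> torch.Tensor:
    # region_mask: (batch, query_seq_len, key_seq_len), values in {0.0, 1.0}
    # padding_mask: same shape, additive bias (0.0 = keep, large negative = drop)
    bias = region_mask.to(dtype=padding_mask.dtype, device=padding_mask.device).clone()
    bias[bias < 0.5] = -10000.0  # block: queries outside the prompt's region
    bias[bias >= 0.5] = 0.0      # allow: queries inside the prompt's region
    return bias + padding_mask


region = torch.tensor([[[1.0, 0.0], [0.0, 1.0]]])
padding = torch.zeros(1, 2, 2)
print(combine_masks(region, padding))  # 0.0 where attention is allowed, -10000.0 elsewhere
```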

View File

@@ -36,6 +36,142 @@ ModelForwardCallback: TypeAlias = Union[
]
class RegionalTextConditioningInfo:
def __init__(
self,
text_conditioning: Union[BasicConditioningInfo, SDXLConditioningInfo],
masks: Optional[torch.Tensor] = None,
embedding_ranges: Optional[list[Range]] = None,
):
"""Initialize a RegionalTextConditioningInfo.
Args:
text_conditioning (Union[BasicConditioningInfo, SDXLConditioningInfo]): The text conditioning embeddings
after concatenating the embeddings for all regions.
masks (Optional[torch.Tensor], optional): Shape: (1, num_regions, h, w).
embedding_ranges (Optional[list[Range]], optional): The embedding range for each region.
"""
self.text_conditioning = text_conditioning
self.masks = masks
self.embedding_ranges = embedding_ranges
assert (self.masks is None) == (self.embedding_ranges is None)
if self.masks is not None:
assert self.masks.shape[1] == len(self.embedding_ranges)
def has_region_masks(self):
if self.masks is None:
return False
return any(mask is not None for mask in self.masks)
def is_sdxl(self):
return isinstance(self.text_conditioning, SDXLConditioningInfo)
@classmethod
def _preprocess_regional_prompt_mask(
cls, mask: Optional[torch.Tensor], target_height: int, target_width: int
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
If mask is not None, resizes the mask to the target height and width using nearest neighbor interpolation.
Returns:
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
"""
if mask is None:
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
tf = torchvision.transforms.Resize(
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
)
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
mask = tf(mask)
return mask
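
A self-contained sketch of the same preprocessing, using `torch.nn.functional.interpolate` in place of `torchvision.transforms.Resize` (assumed equivalent here for nearest-neighbor resizing; the function name is illustrative): a missing mask means "apply everywhere", and provided masks are resized to the latent resolution so they line up with the flattened query positions.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def preprocess_region_mask(mask: Optional[torch.Tensor], target_h: int, target_w: int) -> torch.Tensor:
    if mask is None:
        # No mask: the prompt conditions every latent position.
        return torch.ones((1, 1, target_h, target_w), dtype=torch.bool)
    mask = mask.unsqueeze(0).to(torch.float32)  # (1, h, w) -> (1, 1, h, w)
    mask = F.interpolate(mask, size=(target_h, target_w), mode="nearest")
    return mask > 0.5


print(preprocess_region_mask(None, 4, 4).shape)                 # torch.Size([1, 1, 4, 4])
print(preprocess_region_mask(torch.ones(1, 8, 8), 4, 4).dtype)  # torch.bool
```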
@classmethod
def from_text_conditioning_and_masks(
cls,
text_conditionings: list[Union[BasicConditioningInfo, SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
latent_width: int,
):
if masks is None:
masks = [None] * len(text_conditionings)
assert len(text_conditionings) == len(masks)
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
all_masks_are_none = all(mask is None for mask in masks)
text_embedding = []
pooled_embedding = None
add_time_ids = None
processed_masks = []
cur_text_embedding_len = 0
embedding_ranges: list[Range] = []
for text_embedding_info, mask in zip(text_conditionings, masks, strict=True):
# HACK(ryand): Figure out the intended relationship between CAC and other conditioning features.
assert (
text_embedding_info.extra_conditioning is None
or not text_embedding_info.extra_conditioning.wants_cross_attention_control
)
if is_sdxl:
# We just use the first SDXLConditioningInfo's pooled_embeds and add_time_ids.
# TODO(ryand): Think about this some more. If we can't use the pooled_embeds and add_time_ids from all
# the conditioning info, then we shouldn't allow it to be passed in.
# How does Compel handle this? Options that come to mind:
# - Blend the pooled_embeds and add_time_ids from all of the text embeddings.
# - Use the pooled_embeds and add_time_ids from the text embedding with the largest mask area, since
# this is likely the global prompt.
if pooled_embedding is None:
pooled_embedding = text_embedding_info.pooled_embeds
if add_time_ids is None:
add_time_ids = text_embedding_info.add_time_ids
text_embedding.append(text_embedding_info.embeds)
embedding_ranges.append(
Range(start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1])
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
if not all_masks_are_none:
processed_masks.append(cls._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
text_embedding = torch.cat(text_embedding, dim=1)
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
if not all_masks_are_none:
processed_masks = torch.cat(processed_masks, dim=1)
else:
processed_masks = None
embedding_ranges = None
if is_sdxl:
return cls(
text_conditioning=SDXLConditioningInfo(
embeds=text_embedding,
extra_conditioning=None,
pooled_embeds=pooled_embedding,
add_time_ids=add_time_ids,
),
masks=processed_masks,
embedding_ranges=embedding_ranges,
)
return cls(
text_conditioning=BasicConditioningInfo(
embeds=text_embedding,
extra_conditioning=None,
),
masks=processed_masks,
embedding_ranges=embedding_ranges,
)
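
A condensed sketch of the concatenation bookkeeping in `from_text_conditioning_and_masks`: per-prompt embeddings are concatenated along the token axis and a `Range` is recorded for each prompt's slice. The dataclass here is a stand-in for the `Range` type in this diff, not an import from the repo.

```python
from dataclasses import dataclass

import torch


@dataclass
class _Range:  # mirrors Range(start, end) used by RegionalPromptData
    start: int
    end: int


def concat_prompt_embeddings(embeds: list[torch.Tensor]) -> tuple[torch.Tensor, list[_Range]]:
    ranges, cursor = [], 0
    for e in embeds:  # each e: (1, seq_len, embed_dim)
        ranges.append(_Range(start=cursor, end=cursor + e.shape[1]))
        cursor += e.shape[1]
    return torch.cat(embeds, dim=1), ranges


embedding, ranges = concat_prompt_embeddings([torch.zeros(1, 77, 768), torch.zeros(1, 77, 768)])
print(embedding.shape, [(r.start, r.end) for r in ranges])  # torch.Size([1, 154, 768]) [(0, 77), (77, 154)]
```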
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@@ -59,7 +195,6 @@ class InvokeAIDiffuserComponent:
:param model_forward_callback: a lambda with arguments (x, sigma, conditioning_to_apply). will be called repeatedly. most likely, this should simply call model.forward(x, sigma, conditioning)
"""
config = InvokeAIAppConfig.get_config()
self.conditioning = None
self.model = model
self.model_forward_callback = model_forward_callback
self.cross_attention_control_context = None
@@ -433,14 +568,44 @@ class InvokeAIDiffuserComponent:
# denoising step.
cross_attention_kwargs = None
_, _, h, w = x.shape
text_embeddings, regional_prompt_data = self._prepare_text_embeddings(
text_embeddings=conditioning_data.text_embeddings,
cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
text_conditionings=conditioning_data.text_embeddings,
masks=conditioning_data.text_embedding_masks,
target_height=h,
target_width=w,
latent_height=h,
latent_width=w,
)
if regional_prompt_data is not None:
cross_attention_kwargs = {"regional_prompt_data": regional_prompt_data}
uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
text_conditionings=[conditioning_data.unconditioned_embeddings],
masks=[None],
latent_height=h,
latent_width=w,
)
if cond_text.has_region_masks() or uncond_text.has_region_masks():
masks = []
embedding_ranges = []
for c in [uncond_text, cond_text]:
if c.has_region_masks():
masks.append(c.masks)
embedding_ranges.append(c.embedding_ranges)
else:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
masks.append(torch.ones((1, 1, h, w), dtype=torch.bool))
embedding_ranges.append([Range(start=0, end=c.text_conditioning.embeds.shape[1])])
# The key_seq_len will be the maximum sequence length of all the conditioning embeddings. All other
# embeddings will be padded to match this length.
key_seq_len = 0
for c in [uncond_text, cond_text]:
_, seq_len, _ = c.text_conditioning.embeds.shape
if seq_len > key_seq_len:
key_seq_len = seq_len
cross_attention_kwargs = {
"regional_prompt_data": RegionalPromptData.from_masks_and_ranges(
masks=masks, embedding_ranges=embedding_ranges, key_seq_len=key_seq_len
)
}
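
The effect of this block, sketched with plain tensors (the sequence lengths are made-up examples): conditioning that has no region masks, typically the unconditioned embedding, gets a single all-ones mask over the full latent and one range covering its whole token sequence, so both halves of the CFG batch flow through the same regional-mask path.

```python
import torch

h, w = 64, 64
uncond_seq_len, cond_seq_len = 77, 154  # e.g. one uncond prompt vs. two concatenated regional prompts

dummy_mask = torch.ones((1, 1, h, w), dtype=torch.bool)  # "applies everywhere"
dummy_range = (0, uncond_seq_len)                        # one range spanning the whole embedding

# key_seq_len is the longest embedding in the batch; shorter ones are padded to it.
key_seq_len = max(uncond_seq_len, cond_seq_len)
print(dummy_mask.shape, dummy_range, key_seq_len)
```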
# TODO(ryand): Figure out interactions between regional prompting and IP-Adapter conditioning.
if conditioning_data.ip_adapter_conditioning is not None:
@@ -455,27 +620,18 @@
}
added_cond_kwargs = None
if type(text_embeddings) is SDXLConditioningInfo:
if cond_text.is_sdxl():
added_cond_kwargs = {
"text_embeds": torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds,
text_embeddings.pooled_embeds,
],
dim=0,
[uncond_text.text_conditioning.pooled_embeds, cond_text.text_conditioning.pooled_embeds], dim=0
),
"time_ids": torch.cat(
[
conditioning_data.unconditioned_embeddings.add_time_ids,
text_embeddings.add_time_ids,
],
dim=0,
[uncond_text.text_conditioning.add_time_ids, cond_text.text_conditioning.add_time_ids], dim=0
),
}
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
uncond_text.text_conditioning.embeds, cond_text.text_conditioning.embeds
)
both_results = self.model_forward_callback(
x_twice,
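
`_concat_conditionings_for_batch` itself is not shown in this diff; the sketch below is one plausible reading of the pattern it implies (and of the removed "how to pad?" TODO): zero-pad the shorter embedding along the token axis, record which keys are padding, and stack uncond/cond along the batch dimension for a single forward pass.

```python
import torch


def concat_for_batch(uncond: torch.Tensor, cond: torch.Tensor):
    # uncond, cond: (1, seq_len, embed_dim); seq_len may differ (e.g. 77 vs 154).
    max_len = max(uncond.shape[1], cond.shape[1])
    mask = None
    if uncond.shape[1] != cond.shape[1]:
        mask = torch.ones((2, max_len))
        mask[0, uncond.shape[1]:] = 0.0  # mark padded uncond tokens
        mask[1, cond.shape[1]:] = 0.0    # mark padded cond tokens
        uncond = torch.cat([uncond, torch.zeros(1, max_len - uncond.shape[1], uncond.shape[2])], dim=1)
        cond = torch.cat([cond, torch.zeros(1, max_len - cond.shape[1], cond.shape[2])], dim=1)
    return torch.cat([uncond, cond], dim=0), mask


both, encoder_attention_mask = concat_for_batch(torch.zeros(1, 77, 768), torch.zeros(1, 154, 768))
print(both.shape)  # torch.Size([2, 154, 768])
```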