This commit is contained in:
Ryan Dick 2024-03-06 10:52:35 -05:00
parent d3a40c5b2b
commit b8cbff828b

View File

@ -59,7 +59,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
scale: float = 1.0,
# For regional prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None,
percent_through: Optional[float] = None,
# For IP-Adapter:
ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.FloatTensor:
@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
dtype=hidden_states.dtype, device=hidden_states.device
)
attn_mask_weight = 1.0
attn_mask_weight = 0.8
else: # self-attention
prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
query_seq_len=query_seq_len,
percent_through=percent_through,
)
attn_mask_weight = 0.3
attn_mask_weight = 0.5
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@ -142,7 +142,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if regional_prompt_data is not None:
if regional_prompt_data is not None and percent_through < 0.5:
prompt_region_attention_mask = attn.prepare_attention_mask(
prompt_region_attention_mask, sequence_length, batch_size
)
@ -161,10 +161,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
)
if attention_mask is None:
attention_mask = prompt_region_attention_mask
if attention_mask is None:
attention_mask = prompt_region_attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
else:
attention_mask = prompt_region_attention_mask + attention_mask
pass
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1