Ryan Dick 2024-03-06 10:52:35 -05:00
parent d3a40c5b2b
commit b8cbff828b

@@ -59,7 +59,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         scale: float = 1.0,
         # For regional prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[float] = None,
         # For IP-Adapter:
         ip_adapter_image_prompt_embeds: Optional[list[torch.Tensor]] = None,
     ) -> torch.FloatTensor:
@@ -106,13 +106,13 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                     dtype=hidden_states.dtype, device=hidden_states.device
                 )
-                attn_mask_weight = 1.0
+                attn_mask_weight = 0.8
             else:  # self-attention
                 prompt_region_attention_mask = regional_prompt_data.get_self_attn_mask(
                     query_seq_len=query_seq_len,
                     percent_through=percent_through,
                 )
-                attn_mask_weight = 0.3
+                attn_mask_weight = 0.5
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
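Not part of the commit, but as context for the constants in the hunk above: a minimal, hypothetical sketch of the weight selection and percent_through gating this change introduces. select_regional_mask_weight is an illustrative helper, not a function in the InvokeAI codebase; the 0.5 threshold and the 0.8/0.5 weights are taken directly from the hunks in this commit.

from typing import Optional


def select_regional_mask_weight(is_cross_attention: bool, percent_through: Optional[float]) -> Optional[float]:
    # Regional masking is only applied during the first half of denoising;
    # returning None means "skip regional masking for this step".
    if percent_through is None or percent_through >= 0.5:
        return None
    # Cross-attention masks are weighted more heavily than self-attention masks.
    return 0.8 if is_cross_attention else 0.5


print(select_regional_mask_weight(is_cross_attention=True, percent_through=0.2))   # 0.8
print(select_regional_mask_weight(is_cross_attention=False, percent_through=0.2))  # 0.5
print(select_regional_mask_weight(is_cross_attention=True, percent_through=0.7))   # None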
@@ -142,7 +142,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
 
-        if regional_prompt_data is not None:
+        if regional_prompt_data is not None and percent_through < 0.5:
             prompt_region_attention_mask = attn.prepare_attention_mask(
                 prompt_region_attention_mask, sequence_length, batch_size
             )
@@ -161,10 +161,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)
             )
 
             if attention_mask is None:
                 attention_mask = prompt_region_attention_mask
             else:
                 attention_mask = prompt_region_attention_mask + attention_mask
+        else:
+            pass
 
         # the output of sdp = (batch, num_heads, seq_len, head_dim)
         # TODO: add support for attn.scale when we move to Torch 2.1
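Again not part of the commit: a self-contained sketch of how a binary region mask can be folded into the attention mask as an additive bias before scaled dot-product attention, mirroring the m_pos/m_neg pattern visible in the last hunk. The tensor shapes, and deriving m_pos/m_neg directly from attn_mask_weight, are assumptions for illustration only, not the project's implementation.

import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 8, 64, 77, 40
query = torch.randn(batch, heads, q_len, head_dim)
key = torch.randn(batch, heads, kv_len, head_dim)
value = torch.randn(batch, heads, kv_len, head_dim)

# Binary region mask: 1.0 where a query position is allowed to attend to a key position.
prompt_region_attention_mask = torch.randint(0, 2, (batch, heads, q_len, kv_len)).float()

# Turn the binary mask into an additive bias: boost in-region scores, penalize
# out-of-region scores. Using attn_mask_weight for both m_pos and m_neg is an assumption.
attn_mask_weight = 0.8
m_pos = attn_mask_weight
m_neg = attn_mask_weight
region_bias = m_pos * prompt_region_attention_mask - m_neg * (1.0 - prompt_region_attention_mask)

# If a base attention mask (e.g. padding) already exists, the additive biases simply sum.
attention_mask = None
attention_mask = region_bias if attention_mask is None else attention_mask + region_bias

# The output of SDP is (batch, num_heads, seq_len, head_dim).
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
print(out.shape)  # torch.Size([1, 8, 64, 40])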