fix: entire reshaping block needs to be skipped

This commit is contained in:
blessedcoolant 2024-04-16 04:29:53 +05:30
parent 7ee3fef2db
commit d27907cc6d

View File

@ -178,15 +178,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
) )
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim batch_size, -1, attn.heads * head_dim
) )
ip_hidden_states = ip_hidden_states.to(query.dtype) ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else: else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in. # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None assert regional_ip_data is None