mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: entire reshaping block needs to be skipped
This commit is contained in:
parent
7ee3fef2db
commit
d27907cc6d
@ -178,15 +178,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||
batch_size, -1, attn.heads * head_dim
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||
else:
|
||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||
assert regional_ip_data is None
|
||||
|
Loading…
Reference in New Issue
Block a user