mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix: entire reshaping block needs to be skipped
This commit is contained in:
parent
7ee3fef2db
commit
d27907cc6d
@ -178,15 +178,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
|
||||||
batch_size, -1, attn.heads * head_dim
|
batch_size, -1, attn.heads * head_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||||
|
|
||||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||||
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
|
||||||
else:
|
else:
|
||||||
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
|
||||||
assert regional_ip_data is None
|
assert regional_ip_data is None
|
||||||
|
Loading…
Reference in New Issue
Block a user