mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
(minor) Make it more clear that shape annotations are just comments and not commented lines of code.
This commit is contained in:
parent
53b6f0dc73
commit
3079c75a60
@ -138,29 +138,29 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
ip_hidden_states = ipa_embed
|
||||
|
||||
# ip_hidden_state.shape = (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
# Expected ip_hidden_states shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# ip_key.shape, ip_value.shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# ip_key.shape, ip_value.shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
# ip_hidden_states.shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
# ip_hidden_states.shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
|
||||
|
||||
hidden_states = hidden_states + scale * ip_hidden_states
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user