(minor) Make it clearer that the shape annotations are explanatory comments, not commented-out lines of code.

This commit is contained in:
Ryan Dick 2023-10-16 08:35:32 -04:00
parent 53b6f0dc73
commit 3079c75a60

View File

@ -138,29 +138,29 @@ class IPAttnProcessor2_0(torch.nn.Module):
ip_hidden_states = ipa_embed ip_hidden_states = ipa_embed
# ip_hidden_state.shape = (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) # Expected ip_hidden_states shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
ip_key = ipa_weights.to_k_ip(ip_hidden_states) ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states) ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# ip_key.shape, ip_value.shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# ip_key.shape, ip_value.shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1 # TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention( ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
) )
# ip_hidden_states.shape: (batch_size, num_heads, query_seq_len, head_dim) # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype) ip_hidden_states = ip_hidden_states.to(query.dtype)
# ip_hidden_states.shape: (batch_size, query_seq_len, num_heads * head_dim) # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + scale * ip_hidden_states hidden_states = hidden_states + scale * ip_hidden_states