refactor: fix a bunch of type issues in custom_attention

commit 9cb0f63c44
parent 2d5786d3bb
Author: blessedcoolant
Date:   2024-04-13 14:17:25 +05:30

2 changed files with 18 additions and 13 deletions

First changed file:

@@ -1,4 +1,4 @@
-from typing import List, Optional, TypedDict
+from typing import List, Optional, TypedDict, cast

 import torch
 import torch.nn.functional as F
@@ -40,15 +40,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        # For regional prompting:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
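Note on the annotation change: in torch's type stubs, virtually every tensor operation is declared to return the base torch.Tensor class, so parameters annotated as torch.FloatTensor trip the checker as soon as the value is reassigned from an op's result. A minimal sketch of the mismatch (function names are illustrative, not from the commit):

import torch

def scale(x: torch.FloatTensor) -> torch.FloatTensor:
    # Tensor.__mul__ is annotated as returning torch.Tensor, so a strict
    # checker flags this return value as Tensor, not FloatTensor.
    return x * 2.0

def scale_ok(x: torch.Tensor) -> torch.Tensor:
    # Annotating with the base Tensor class checks cleanly; FloatTensor
    # instances are Tensors at runtime, so callers are unaffected.
    return x * 2.0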
@@ -155,16 +157,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)

-            if self._ip_adapter_attention_weights["skip"]:
+            if not self._ip_adapter_attention_weights["skip"]:
                 ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                 ip_value = ipa_weights.to_v_ip(ip_hidden_states)

-                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)

                 ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

-                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)

                 # TODO: add support for attn.scale when we move to Torch 2.1
                 ip_hidden_states = F.scaled_dot_product_attention(
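The view/transpose pair restated by the split comments is the standard multi-head reshape before F.scaled_dot_product_attention. A standalone sketch of the shape math, with arbitrary illustrative dimensions:

import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 2, 8, 64
seq_len, num_ip_images, ip_seq_len = 77, 1, 4

# Flat projections, shaped like the to_k_ip / to_v_ip outputs above.
query = torch.randn(batch_size, num_heads, seq_len, head_dim)
ip_key = torch.randn(batch_size, num_ip_images * ip_seq_len, num_heads * head_dim)

# (batch, seq, heads * dim) -> (batch, heads, seq, dim)
ip_key = ip_key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
ip_value = torch.randn_like(ip_key)

out = F.scaled_dot_product_attention(query, ip_key, ip_value)
assert out.shape == (batch_size, num_heads, seq_len, head_dim)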
@@ -193,6 +197,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         hidden_states = attn.to_out[1](hidden_states)

         if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

         if attn.residual_connection:
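The inserted shape unpack plausibly exists to give the type checker definite assignments for batch_size/channel/height/width, which were previously bound only inside an earlier `if input_ndim == 4:` branch. A minimal reproduction of that diagnostic, under that assumption (names illustrative):

import torch

def roundtrip(x: torch.Tensor) -> torch.Tensor:
    input_ndim = x.ndim
    if input_ndim == 4:
        batch_size, channel, height, width = x.shape
        x = x.view(batch_size, channel, height * width).transpose(1, 2)

    # ... attention over the flattened sequence would happen here ...

    if input_ndim == 4:
        # Without a fresh assignment in this branch, a strict checker reports
        # batch_size/channel/height/width as possibly unbound: it cannot
        # prove the two `if` bodies always execute together.
        x = x.transpose(-1, -2).reshape(batch_size, channel, height, width)
    return x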
@@ -200,4 +205,4 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):

         hidden_states = hidden_states / attn.rescale_output_factor

-        return hidden_states
+        return cast(torch.FloatTensor, hidden_states)
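typing.cast (imported in the first hunk) is a runtime no-op: it performs no conversion and exists only to reassert the static type, letting the body work with plain torch.Tensor while the method's declared torch.FloatTensor return annotation stays intact. A standalone sketch:

from typing import cast

import torch

def as_float_tensor(x: torch.Tensor) -> torch.FloatTensor:
    # No runtime effect; this only tells the checker to treat x as
    # FloatTensor from here on.
    return cast(torch.FloatTensor, x)

y = as_float_tensor(torch.randn(2, 3))
assert y.dtype == torch.float32  # unchanged by the cast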

Second changed file:

@@ -36,10 +36,10 @@ class UNetAttentionPatcher:

            ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
            for ip_adapter in self._ip_adapters:
                ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
-                skip = False
+                skip = True
                for block in ip_adapter["target_blocks"]:
                    if block in name:
-                        skip = True
+                        skip = False
                        break
                ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
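The skip inversion is the behavioral fix here: previously the flag started False and was set True when a layer name matched one of the adapter's target blocks, which skipped exactly the layers the IP adapter was meant to target. With the defaults flipped, the adapter now applies only to matching layers. A worked example with hypothetical layer and target names:

# Hypothetical names, for illustration only.
name = "down_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"
target_blocks = ["down_blocks", "mid_block"]

skip = True  # fixed default: skip unless a target block matches
for block in target_blocks:
    if block in name:
        skip = False
        break

assert skip is False  # "down_blocks" matched, so this layer is processed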