refactor: fix a bunch of type issues in custom_attention

blessedcoolant 2024-04-13 14:17:25 +05:30
parent 2d5786d3bb
commit 9cb0f63c44
2 changed files with 18 additions and 13 deletions


@@ -1,4 +1,4 @@
from typing import List, Optional, TypedDict
from typing import List, Optional, TypedDict, cast
import torch
import torch.nn.functional as F
@@ -40,15 +40,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
# For regional prompting:
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
# For Regional Prompting:
regional_prompt_data: Optional[RegionalPromptData] = None,
percent_through: Optional[torch.FloatTensor] = None,
percent_through: Optional[torch.Tensor] = None,
# For IP-Adapter:
regional_ip_data: Optional[RegionalIPData] = None,
*args,
**kwargs,
) -> torch.FloatTensor:
"""Apply attention.
Args:
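
The signature now takes plain torch.Tensor arguments instead of the legacy torch.FloatTensor aliases. At runtime every tensor is a torch.Tensor (its dtype lives in the .dtype attribute, not in the class), so the narrower annotations tend to make type checkers reject perfectly valid arguments. A minimal illustration, separate from the diff (the scale helper is not from the file):

import torch

x = torch.randn(2, 3)
print(type(x))   # <class 'torch.Tensor'>, even though x holds float32 data
print(x.dtype)   # torch.float32

# Annotating parameters with torch.Tensor matches what callers actually pass;
# torch.FloatTensor is a stricter static type that ordinary tensors rarely satisfy.
def scale(t: torch.Tensor, factor: float) -> torch.Tensor:
    return t * factor

print(scale(x, 2.0).shape)  # torch.Size([2, 3])
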
@@ -155,16 +157,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if self._ip_adapter_attention_weights["skip"]:
if not self._ip_adapter_attention_weights["skip"]:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
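
For context on the reshapes in this hunk: the IP-Adapter key/value projections start out with the flat head_dim * num_heads embedding described in the comments, are split into a per-head layout, and are then fed to F.scaled_dot_product_attention together with the query. A self-contained sketch of that pattern, with illustrative shapes and a stand-in query tensor (neither is taken from the file):

import torch
import torch.nn.functional as F

batch_size, heads, head_dim = 2, 8, 64
seq_len, num_ip_images, ip_seq_len = 77, 1, 4

query = torch.randn(batch_size, heads, seq_len, head_dim)

# Projected IP-Adapter keys/values, as in the comment above:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = torch.randn(batch_size, num_ip_images, ip_seq_len, heads * head_dim)
ip_value = torch.randn(batch_size, num_ip_images, ip_seq_len, heads * head_dim)

# Fold the image and token dims together and split out the heads:
# -> (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
ip_key = ip_key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, heads, head_dim).transpose(1, 2)

# Cross-attention against the image-prompt tokens; output matches the query's layout.
out = F.scaled_dot_product_attention(query, ip_key, ip_value, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 77, 64])
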
@@ -193,6 +197,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
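
The trailing reshape undoes the flattening applied at the top of __call__ when the incoming hidden states are 4-D (batch, channel, height, width): spatial positions are moved back from the sequence dimension into height and width. A small round-trip sketch, independent of the diff:

import torch

batch_size, channel, height, width = 1, 320, 32, 32

# Spatial input is flattened to (batch, height * width, channel) for attention...
x = torch.randn(batch_size, channel, height, width)
tokens = x.view(batch_size, channel, height * width).transpose(1, 2)

# ...and restored afterwards with the same transpose/reshape used above.
restored = tokens.transpose(-1, -2).reshape(batch_size, channel, height, width)
print(torch.equal(x, restored))  # True
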
@@ -200,4 +205,4 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
return cast(torch.FloatTensor, hidden_states)
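
typing.cast has no runtime effect; it simply returns its argument while telling the type checker to treat it as the named type, so the declared torch.FloatTensor return annotation is satisfied without converting the tensor. A minimal sketch of the same pattern (the double helper is illustrative, not from the file):

from typing import cast

import torch


def double(t: torch.Tensor) -> torch.FloatTensor:
    out = t * 2  # statically a torch.Tensor, which a checker won't treat as FloatTensor
    # cast() is a no-op at runtime; it only tells the checker to trust the narrower type.
    return cast(torch.FloatTensor, out)


result = double(torch.randn(3))
print(result.dtype)  # torch.float32 -- the tensor itself is unchanged
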


@@ -36,10 +36,10 @@ class UNetAttentionPatcher:
ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
for ip_adapter in self._ip_adapters:
ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = False
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = True
skip = False
break
ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
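
With the flipped defaults, skip now starts as True and is cleared only when the attention layer's name matches one of the IP-Adapter's target blocks; that pairs with the inverted `if not self._ip_adapter_attention_weights["skip"]:` check in the processor, so IP-Adapter attention runs exactly for the targeted layers. A self-contained sketch of the corrected control flow (function and argument names are simplified stand-ins, not the file's API):

def should_skip(layer_name: str, target_blocks: list[str]) -> bool:
    """Skip IP-Adapter attention unless the layer belongs to a targeted block."""
    skip = True  # assume skip until a matching target block is found
    for block in target_blocks:
        if block in layer_name:
            skip = False
            break
    return skip


# e.g. an IP-Adapter scoped to the UNet's up blocks:
print(should_skip("down_blocks.0.attentions.0", ["up_blocks"]))  # True  (skipped)
print(should_skip("up_blocks.1.attentions.1", ["up_blocks"]))    # False (applied)
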