mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
refactor: fix a bunch of type issues in custom_attention
This commit is contained in:
parent
2d5786d3bb
commit
9cb0f63c44
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional, TypedDict
|
||||
from typing import List, Optional, TypedDict, cast
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -40,15 +40,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.FloatTensor,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
# For regional prompting:
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
# For Regional Prompting:
|
||||
regional_prompt_data: Optional[RegionalPromptData] = None,
|
||||
percent_through: Optional[torch.FloatTensor] = None,
|
||||
percent_through: Optional[torch.Tensor] = None,
|
||||
# For IP-Adapter:
|
||||
regional_ip_data: Optional[RegionalIPData] = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> torch.FloatTensor:
|
||||
"""Apply attention.
|
||||
Args:
|
||||
@ -155,16 +157,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
|
||||
|
||||
if self._ip_adapter_attention_weights["skip"]:
|
||||
if not self._ip_adapter_attention_weights["skip"]:
|
||||
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
|
||||
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
# Expected ip_key and ip_value shape:
|
||||
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
|
||||
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
@ -193,6 +197,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
@ -200,4 +205,4 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
return cast(torch.FloatTensor, hidden_states)
|
||||
|
@ -36,10 +36,10 @@ class UNetAttentionPatcher:
|
||||
ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
|
||||
for ip_adapter in self._ip_adapters:
|
||||
ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
|
||||
skip = False
|
||||
skip = True
|
||||
for block in ip_adapter["target_blocks"]:
|
||||
if block in name:
|
||||
skip = True
|
||||
skip = False
|
||||
break
|
||||
|
||||
ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
|
||||
|
Loading…
Reference in New Issue
Block a user