Mirror of https://github.com/invoke-ai/InvokeAI
refactor: fix a bunch of type issues in custom_attention
Parent: 2d5786d3bb
Commit: 9cb0f63c44
@@ -1,4 +1,4 @@
-from typing import List, Optional, TypedDict
+from typing import List, Optional, TypedDict, cast
 
 import torch
 import torch.nn.functional as F
@@ -40,15 +40,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        # For regional prompting:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
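Note on the signature change: torch.FloatTensor in annotations is a legacy Tensor subclass, while most torch operations are typed as returning plain torch.Tensor, so the narrower parameter annotations trip type checkers. A minimal sketch of the mismatch (the function name here is illustrative, not from the commit):

    import torch

    def scale(x: torch.FloatTensor) -> torch.FloatTensor:
        y = x * 2.0  # type checkers infer y as Tensor, not FloatTensor
        return y     # error: Tensor is not assignable to FloatTensor

The added *args and **kwargs presumably keep the signature compatible with callers that pass extra arguments, as the upstream diffusers attention processors accept.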
@@ -155,16 +157,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
-            if self._ip_adapter_attention_weights["skip"]:
+            if not self._ip_adapter_attention_weights["skip"]:
                 ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                 ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
-                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
 
                 ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
 
                 # TODO: add support for attn.scale when we move to Torch 2.1
                 ip_hidden_states = F.scaled_dot_product_attention(
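Besides wrapping the long shape comments, this hunk fixes a logic bug: the IP-Adapter key/value projection and attention now run when "skip" is False, where previously the condition was inverted. The reshape itself is unchanged; a small standalone sketch with made-up sizes walks through the shapes the comments describe:

    import torch

    batch_size, num_ip_images, ip_seq_len, heads, head_dim = 2, 1, 4, 8, 40
    ip_key = torch.randn(batch_size, num_ip_images, ip_seq_len, heads * head_dim)

    # (batch, num_ip_images, ip_seq_len, heads * head_dim)
    #   -> (batch, heads, num_ip_images * ip_seq_len, head_dim)
    ip_key = ip_key.view(batch_size, -1, heads, head_dim).transpose(1, 2)
    print(ip_key.shape)  # torch.Size([2, 8, 4, 40])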
@@ -193,6 +197,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
         hidden_states = attn.to_out[1](hidden_states)
 
         if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
         if attn.residual_connection:
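The added unpack plausibly silences a "possibly unbound" diagnostic: batch_size, channel, height, and width were previously bound only inside an earlier if input_ndim == 4: branch, so a type checker cannot prove they are defined at the final reshape. A hypothetical, self-contained sketch of that pattern:

    import torch

    def roundtrip(x: torch.Tensor) -> torch.Tensor:
        input_ndim = x.ndim
        if input_ndim == 4:
            b, c, h, w = x.shape  # names bound only on this branch
            x = x.view(b, c, h * w).transpose(1, 2)
        # ... attention runs here on the 3-D tensor ...
        if input_ndim == 4:
            # without re-binding, checkers flag b/c/h/w as possibly unbound
            x = x.transpose(-1, -2).reshape(b, c, h, w)
        return x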
@@ -200,4 +205,4 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         hidden_states = hidden_states / attn.rescale_output_factor
 
-        return hidden_states
+        return cast(torch.FloatTensor, hidden_states)
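typing.cast is a no-op at runtime and only narrows the type for the checker, so it preserves the declared torch.FloatTensor return type without copying or converting the tensor:

    from typing import cast

    import torch

    t = torch.ones(2)
    ft = cast(torch.FloatTensor, t)  # same object, just re-typed for the checker
    assert ft is t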
@@ -36,10 +36,10 @@ class UNetAttentionPatcher:
             ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
             for ip_adapter in self._ip_adapters:
                 ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
-                skip = False
+                skip = True
                 for block in ip_adapter["target_blocks"]:
                     if block in name:
-                        skip = True
+                        skip = False
                         break
 
                 ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})
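This hunk (in the UNetAttentionPatcher class) inverts the default of the skip flag: an attention layer is now skipped unless its name matches one of the IP-Adapter's target blocks, which is consistent with the corrected if not ...["skip"] check above. A minimal sketch with hypothetical block names:

    name = "down_blocks.1.attentions.0"
    target_blocks = ["down_blocks.2", "mid_block"]

    skip = True  # skip by default...
    for block in target_blocks:
        if block in name:
            skip = False  # ...unless this layer is one of the targets
            break
    print(skip)  # True: not targeted, so IP-Adapter attention is skipped here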