diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index 2a15e4fbe2..1dc4a43b2f 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -1,4 +1,4 @@
-from typing import List, Optional, TypedDict
+from typing import List, Optional, TypedDict, cast
 
 import torch
 import torch.nn.functional as F
@@ -40,15 +40,17 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
     def __call__(
         self,
         attn: Attention,
-        hidden_states: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        temb: Optional[torch.FloatTensor] = None,
-        # For regional prompting:
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        # For Regional Prompting:
         regional_prompt_data: Optional[RegionalPromptData] = None,
-        percent_through: Optional[torch.FloatTensor] = None,
+        percent_through: Optional[torch.Tensor] = None,
         # For IP-Adapter:
         regional_ip_data: Optional[RegionalIPData] = None,
+        *args,
+        **kwargs,
     ) -> torch.FloatTensor:
         """Apply attention.
         Args:
@@ -155,16 +157,18 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
             # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
-            if self._ip_adapter_attention_weights["skip"]:
+            if not self._ip_adapter_attention_weights["skip"]:
                 ip_key = ipa_weights.to_k_ip(ip_hidden_states)
                 ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
-                # Expected ip_key and ip_value shape: (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
 
                 ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-                # Expected ip_key and ip_value shape: (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                # Expected ip_key and ip_value shape:
+                # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
 
                 # TODO: add support for attn.scale when we move to Torch 2.1
                 ip_hidden_states = F.scaled_dot_product_attention(
@@ -193,6 +197,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             hidden_states = attn.to_out[1](hidden_states)
 
         if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
 
         if attn.residual_connection:
@@ -200,4 +205,4 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
         hidden_states = hidden_states / attn.rescale_output_factor
 
-        return hidden_states
+        return cast(torch.FloatTensor, hidden_states)
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index 05011e3d9a..52cfc2c573 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -36,10 +36,10 @@ class UNetAttentionPatcher:
             ip_adapter_attention_weights: IPAdapterAttentionWeights = {"ip_adapter_weights": [], "skip": False}
             for ip_adapter in self._ip_adapters:
                 ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
-                skip = False
+                skip = True
                 for block in ip_adapter["target_blocks"]:
                     if block in name:
-                        skip = True
+                        skip = False
                         break
                 ip_adapter_attention_weights.update({"ip_adapter_weights": [ip_adapter_weight], "skip": skip})