diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py index 16617d049a..1334313fe6 100644 --- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py +++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py @@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg @dataclass class IPAdapterAttentionWeights: - ip_adapter_weights: Optional[IPAttentionProcessorWeights] + ip_adapter_weights: IPAttentionProcessorWeights skip: bool @@ -64,6 +64,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): is_cross_attention = encoder_hidden_states is not None # Start unmodified block from AttnProcessor2_0. + # vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb) @@ -77,6 +78,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # End unmodified block from AttnProcessor2_0. _, query_seq_len, _ = hidden_states.shape @@ -160,33 +162,32 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding) if not self._ip_adapter_attention_weights[ipa_index].skip: - if ipa_weights: - ip_key = ipa_weights.to_k_ip(ip_hidden_states) - ip_value = ipa_weights.to_v_ip(ip_hidden_states) + ip_key = ipa_weights.to_k_ip(ip_hidden_states) + ip_value = ipa_weights.to_v_ip(ip_hidden_states) - # Expected ip_key and ip_value shape: - # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) + # Expected ip_key and ip_value shape: + # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads) - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - # Expected ip_key and ip_value shape: - # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) + # Expected ip_key and ip_value shape: + # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) + # TODO: add support for attn.scale when we move to Torch 2.1 + ip_hidden_states = F.scaled_dot_product_attention( + query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False + ) - # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( - batch_size, -1, attn.heads * head_dim - ) + # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim) + ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) - ip_hidden_states = ip_hidden_states.to(query.dtype) + ip_hidden_states = ip_hidden_states.to(query.dtype) - # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) - hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask + # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim) + hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask else: # If IP-Adapter is not enabled, then regional_ip_data should not be passed in. assert regional_ip_data is None @@ -206,6 +207,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0): hidden_states = hidden_states + residual hidden_states = hidden_states / attn.rescale_output_factor + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ # End of unmodified block from AttnProcessor2_0 # casting torch.Tensor to torch.FloatTensor to avoid type issues diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py index f74359c614..ac00a8e06e 100644 --- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py +++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py @@ -37,17 +37,15 @@ class UNetAttentionPatcher: ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = [] for ip_adapter in self._ip_adapters: - ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights( - ip_adapter_weights=None, skip=False - ) ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx) skip = True for block in ip_adapter["target_blocks"]: if block in name: skip = False break - ip_adapter_attention_weights.ip_adapter_weights = ip_adapter_weights - ip_adapter_attention_weights.skip = skip + ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights( + ip_adapter_weights=ip_adapter_weights, skip=skip + ) ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights) attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)