diff --git a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
index ac53bf911d..8d7245ae3b 100644
--- a/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
+++ b/invokeai/backend/stable_diffusion/diffusion/custom_atttention.py
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from itertools import cycle, islice
 from typing import List, Optional, cast
 
 import torch
@@ -13,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg
 
 @dataclass
 class IPAdapterAttentionWeights:
-    ip_adapter_weights: List[IPAttentionProcessorWeights]
+    ip_adapter_weights: Optional[IPAttentionProcessorWeights]
     skip: bool
 
 
@@ -28,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
     def __init__(
         self,
-        ip_adapter_attention_weights: Optional[IPAdapterAttentionWeights] = None,
+        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
     ):
         """Initialize a CustomAttnProcessor2_0.
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
@@ -139,23 +138,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
             assert regional_ip_data is not None
             ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
 
-            # Pad weight tensor list to match size of regional embeds
-            self._ip_adapter_attention_weights.ip_adapter_weights = list(
-                islice(
-                    cycle(self._ip_adapter_attention_weights.ip_adapter_weights),
-                    len(regional_ip_data.image_prompt_embeds),
-                )
-            )
-
             assert (
                 len(regional_ip_data.image_prompt_embeds)
-                == len(self._ip_adapter_attention_weights.ip_adapter_weights)
+                == len(self._ip_adapter_attention_weights)
                 == len(regional_ip_data.scales)
                 == ip_masks.shape[1]
            )
 
             for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                ipa_weights = self._ip_adapter_attention_weights.ip_adapter_weights[ipa_index]
+                ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
                 ipa_scale = regional_ip_data.scales[ipa_index]
                 ip_mask = ip_masks[0, ipa_index, ...]
@@ -168,33 +159,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                 # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
-                if not self._ip_adapter_attention_weights.skip:
-                    ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                    ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+                if not self._ip_adapter_attention_weights[ipa_index].skip:
+                    if ipa_weights:
+                        ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                        ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
-                    # Expected ip_key and ip_value shape:
-                    # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
 
-                    ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-                    ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+                        ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
-                    # Expected ip_key and ip_value shape:
-                    # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
+                        # Expected ip_key and ip_value shape:
+                        # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
 
-                    # TODO: add support for attn.scale when we move to Torch 2.1
-                    ip_hidden_states = F.scaled_dot_product_attention(
-                        query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
-                    )
+                        # TODO: add support for attn.scale when we move to Torch 2.1
+                        ip_hidden_states = F.scaled_dot_product_attention(
+                            query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+                        )
 
-                    # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
-                    ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
-                        batch_size, -1, attn.heads * head_dim
-                    )
+                        # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
+                        ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
+                            batch_size, -1, attn.heads * head_dim
+                        )
 
-                    ip_hidden_states = ip_hidden_states.to(query.dtype)
+                        ip_hidden_states = ip_hidden_states.to(query.dtype)
 
-                    # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
-                    hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
+                        # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
+                        hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
         else:
             # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
diff --git a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
index df9b8d6386..e94d78decb 100644
--- a/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
+++ b/invokeai/backend/stable_diffusion/diffusion/unet_attention_patcher.py
@@ -33,20 +33,25 @@ class UNetAttentionPatcher:
                 # "attn1" processors do not use IP-Adapters.
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
-                ip_adapter_attention_weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)
+                total_ip_adapter_attention_weights: list[IPAdapterAttentionWeights] = []
+
                 for ip_adapter in self._ip_adapters:
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=None, skip=False
+                    )
                     ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
                     skip = True
                     for block in ip_adapter["target_blocks"]:
                         if block in name:
                             skip = False
                             break
-                    ip_adapter_attention_weights.ip_adapter_weights = [ip_adapter_weight]
+                    ip_adapter_attention_weights.ip_adapter_weights = ip_adapter_weight
                     ip_adapter_attention_weights.skip = skip
+                    total_ip_adapter_attention_weights.append(ip_adapter_attention_weights)
 
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
-                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights)
+                attn_procs[name] = CustomAttnProcessor2_0(total_ip_adapter_attention_weights)
 
         return attn_procs
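
A minimal, self-contained sketch of the bookkeeping this change introduces: each loaded IP adapter now contributes exactly one entry (its own attention weights plus a per-layer `skip` flag), and the attention processor indexes that list per adapter instead of cycling a shared weight list. `AdapterEntry` and `collect_entries` below are hypothetical stand-ins for `IPAdapterAttentionWeights` and the `UNetAttentionPatcher` loop, not code from this PR:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AdapterEntry:
    """Hypothetical stand-in for IPAdapterAttentionWeights."""

    weights: Optional[str]  # stand-in for IPAttentionProcessorWeights
    skip: bool


def collect_entries(adapters: List[dict], layer_name: str) -> List[AdapterEntry]:
    """Mirror of the patcher loop: one entry per adapter, skip derived from target_blocks."""
    entries: List[AdapterEntry] = []
    for adapter in adapters:
        # An adapter applies to a layer only if the layer name contains one of its target blocks.
        skip = not any(block in layer_name for block in adapter["target_blocks"])
        entries.append(AdapterEntry(weights=adapter["weights"], skip=skip))
    return entries


adapters = [
    {"weights": "style-adapter", "target_blocks": ["up_blocks"]},
    {"weights": "composition-adapter", "target_blocks": ["down_blocks", "mid_block", "up_blocks"]},
]

# A cross-attention layer in an up block: both adapters apply.
entries = collect_entries(adapters, "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor")
assert [e.skip for e in entries] == [False, False]

# A cross-attention layer in a down block: only the second adapter applies.
entries = collect_entries(adapters, "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor")
assert [e.skip for e in entries] == [True, False]
```

The `if ipa_weights:` guard added in `CustomAttnProcessor2_0` mirrors the fact that `ip_adapter_weights` is now `Optional`: an entry whose weights are `None` simply contributes nothing to `hidden_states` for that layer.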