fix: make ip-adapter weights not be optional

This commit is contained in:
blessedcoolant 2024-04-16 21:12:45 +05:30
parent d27907cc6d
commit f46bbaf8c4
2 changed files with 26 additions and 26 deletions

View File

@ -12,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg
@dataclass
class IPAdapterAttentionWeights:
ip_adapter_weights: Optional[IPAttentionProcessorWeights]
ip_adapter_weights: IPAttentionProcessorWeights
skip: bool
@ -64,6 +64,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
is_cross_attention = encoder_hidden_states is not None
# Start unmodified block from AttnProcessor2_0.
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
@ -77,6 +78,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End unmodified block from AttnProcessor2_0.
_, query_seq_len, _ = hidden_states.shape
@ -160,33 +162,32 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
# Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
if not self._ip_adapter_attention_weights[ipa_index].skip:
if ipa_weights:
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
ip_key = ipa_weights.to_k_ip(ip_hidden_states)
ip_value = ipa_weights.to_v_ip(ip_hidden_states)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
# Expected ip_key and ip_value shape:
# (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# Expected ip_key and ip_value shape:
# (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
# Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.to(query.dtype)
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
# Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
else:
# If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
assert regional_ip_data is None
@ -206,6 +207,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# End of unmodified block from AttnProcessor2_0
# casting torch.Tensor to torch.FloatTensor to avoid type issues

View File

@ -37,17 +37,15 @@ class UNetAttentionPatcher:
ip_adapter_attention_weights_collection: list[IPAdapterAttentionWeights] = []
for ip_adapter in self._ip_adapters:
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=None, skip=False
)
ip_adapter_weights = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
skip = True
for block in ip_adapter["target_blocks"]:
if block in name:
skip = False
break
ip_adapter_attention_weights.ip_adapter_weights = ip_adapter_weights
ip_adapter_attention_weights.skip = skip
ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
ip_adapter_weights=ip_adapter_weights, skip=skip
)
ip_adapter_attention_weights_collection.append(ip_adapter_attention_weights)
attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights_collection)