Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
fix: IP Adapter weights being incorrectly applied
They were being overwritten rather than appended.
parent f6b7bc5d98
commit a148c4322c
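In short: the attention patcher built the per-layer IP Adapter weights in a loop but reused a single container, so each adapter clobbered the previous one and only the last adapter's weights reached the attention processor. A minimal, self-contained sketch of the overwrite-vs-append difference (the names below are illustrative stand-ins, not InvokeAI's actual API):

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AdapterEntry:  # stand-in for IPAdapterAttentionWeights
    weights: Optional[str]
    skip: bool


def collect_overwriting(adapters: List[str]) -> List[AdapterEntry]:
    # Buggy pattern: one shared entry is mutated on every pass,
    # so only the last adapter's weights survive.
    entry = AdapterEntry(weights=None, skip=False)
    for adapter in adapters:
        entry.weights = adapter
    return [entry]


def collect_appending(adapters: List[str]) -> List[AdapterEntry]:
    # Fixed pattern: build a fresh entry per adapter and append it.
    entries: List[AdapterEntry] = []
    for adapter in adapters:
        entries.append(AdapterEntry(weights=adapter, skip=False))
    return entries


if __name__ == "__main__":
    adapters = ["ip_adapter_a", "ip_adapter_b"]
    print(len(collect_overwriting(adapters)))  # 1 -> second adapter clobbers the first
    print(len(collect_appending(adapters)))    # 2 -> one entry per adapter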
@@ -1,5 +1,4 @@
 from dataclasses import dataclass
-from itertools import cycle, islice
 from typing import List, Optional, cast
 
 import torch
@@ -13,7 +12,7 @@ from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import Reg
 
 @dataclass
 class IPAdapterAttentionWeights:
-    ip_adapter_weights: List[IPAttentionProcessorWeights]
+    ip_adapter_weights: Optional[IPAttentionProcessorWeights]
     skip: bool
 
 
@@ -28,7 +27,7 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
     def __init__(
         self,
-        ip_adapter_attention_weights: Optional[IPAdapterAttentionWeights] = None,
+        ip_adapter_attention_weights: Optional[List[IPAdapterAttentionWeights]] = None,
     ):
         """Initialize a CustomAttnProcessor2_0.
 
         Note: Arguments that are the same for all attention layers are passed to __call__(). Arguments that are
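With this change, IPAdapterAttentionWeights describes a single IP Adapter (its processor weights plus a per-layer skip flag), and CustomAttnProcessor2_0 receives one such entry per adapter as a list. A rough sketch of how that shape is consumed, with the real IPAttentionProcessorWeights replaced by a plain string for illustration:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class IPAdapterAttentionWeights:
    # One entry per IP Adapter: its attention-processor weights and a flag
    # telling this attention layer whether to skip the adapter.
    ip_adapter_weights: Optional[str]  # placeholder; the real field holds IPAttentionProcessorWeights
    skip: bool


# The processor is now handed a list with one entry per adapter and indexes
# it by adapter position inside its per-adapter loop.
weights_per_adapter: List[IPAdapterAttentionWeights] = [
    IPAdapterAttentionWeights(ip_adapter_weights="adapter_a_weights", skip=False),
    IPAdapterAttentionWeights(ip_adapter_weights="adapter_b_weights", skip=True),
]

for ipa_index, entry in enumerate(weights_per_adapter):
    if not entry.skip and entry.ip_adapter_weights:
        print(f"adapter {ipa_index}: applying {entry.ip_adapter_weights}")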
@@ -139,23 +138,15 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
                 assert regional_ip_data is not None
                 ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
 
-                # Pad weight tensor list to match size of regional embeds
-                self._ip_adapter_attention_weights.ip_adapter_weights = list(
-                    islice(
-                        cycle(self._ip_adapter_attention_weights.ip_adapter_weights),
-                        len(regional_ip_data.image_prompt_embeds),
-                    )
-                )
-
                 assert (
                     len(regional_ip_data.image_prompt_embeds)
-                    == len(self._ip_adapter_attention_weights.ip_adapter_weights)
+                    == len(self._ip_adapter_attention_weights)
                     == len(regional_ip_data.scales)
                     == ip_masks.shape[1]
                 )
 
                 for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
-                    ipa_weights = self._ip_adapter_attention_weights.ip_adapter_weights[ipa_index]
+                    ipa_weights = self._ip_adapter_attention_weights[ipa_index].ip_adapter_weights
                     ipa_scale = regional_ip_data.scales[ipa_index]
                     ip_mask = ip_masks[0, ipa_index, ...]
 
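The deleted block padded the single shared weights list up to the number of regional image-prompt embeds by cycling it; with one IPAdapterAttentionWeights per adapter the lengths already line up, so the padding (and the cycle/islice import) goes away and the length assert can compare the outer list directly. For reference, what the removed padding idiom computes (standalone example, not repo code):

from itertools import cycle, islice

# Repeat a (possibly shorter) list until it reaches the requested length.
weights = ["w0"]
num_regional_embeds = 3
padded = list(islice(cycle(weights), num_regional_embeds))
print(padded)  # ['w0', 'w0', 'w0']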
@@ -168,33 +159,33 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
 
                     # Expected ip_hidden_state shape: (batch_size, num_ip_images, ip_seq_len, ip_image_embedding)
 
-                    if not self._ip_adapter_attention_weights.skip:
-                        ip_key = ipa_weights.to_k_ip(ip_hidden_states)
-                        ip_value = ipa_weights.to_v_ip(ip_hidden_states)
+                    if not self._ip_adapter_attention_weights[ipa_index].skip:
+                        if ipa_weights:
+                            ip_key = ipa_weights.to_k_ip(ip_hidden_states)
+                            ip_value = ipa_weights.to_v_ip(ip_hidden_states)
 
                         # Expected ip_key and ip_value shape:
                         # (batch_size, num_ip_images, ip_seq_len, head_dim * num_heads)
 
                         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
                         # Expected ip_key and ip_value shape:
                         # (batch_size, num_heads, num_ip_images * ip_seq_len, head_dim)
 
                         # TODO: add support for attn.scale when we move to Torch 2.1
                         ip_hidden_states = F.scaled_dot_product_attention(
                             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                         )
 
                         # Expected ip_hidden_states shape: (batch_size, num_heads, query_seq_len, head_dim)
 
                         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                             batch_size, -1, attn.heads * head_dim
                         )
 
                         ip_hidden_states = ip_hidden_states.to(query.dtype)
 
                         # Expected ip_hidden_states shape: (batch_size, query_seq_len, num_heads * head_dim)
 
                         hidden_states = hidden_states + ipa_scale * ip_hidden_states * ip_mask
         else:
             # If IP-Adapter is not enabled, then regional_ip_data should not be passed in.
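For readers following the tensor bookkeeping in this hunk, here is a small standalone sketch of the same key/value reshaping around F.scaled_dot_product_attention, using made-up sizes; the dimension names mirror the comments above, and none of the variables are InvokeAI API:

import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 2, 8, 64
query_seq_len, ip_seq_len = 77, 4

query = torch.randn(batch_size, num_heads, query_seq_len, head_dim)
# The to_k_ip/to_v_ip projections come out as (batch_size, ip_seq_len, num_heads * head_dim).
ip_key = torch.randn(batch_size, ip_seq_len, num_heads * head_dim)
ip_value = torch.randn(batch_size, ip_seq_len, num_heads * head_dim)

# Split heads: (batch_size, num_heads, ip_seq_len, head_dim).
ip_key = ip_key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(
    query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# (batch_size, num_heads, query_seq_len, head_dim) -> (batch_size, query_seq_len, num_heads * head_dim)
out = out.transpose(1, 2).reshape(batch_size, -1, num_heads * head_dim)
print(out.shape)  # torch.Size([2, 77, 512])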
@@ -33,20 +33,25 @@ class UNetAttentionPatcher:
                 # "attn1" processors do not use IP-Adapters.
                 attn_procs[name] = CustomAttnProcessor2_0()
             else:
-                ip_adapter_attention_weights = IPAdapterAttentionWeights(ip_adapter_weights=[], skip=False)
+                total_ip_adapter_attention_weights: list[IPAdapterAttentionWeights] = []
+
                 for ip_adapter in self._ip_adapters:
+                    ip_adapter_attention_weights: IPAdapterAttentionWeights = IPAdapterAttentionWeights(
+                        ip_adapter_weights=None, skip=False
+                    )
                     ip_adapter_weight = ip_adapter["ip_adapter"].attn_weights.get_attention_processor_weights(idx)
                     skip = True
                     for block in ip_adapter["target_blocks"]:
                         if block in name:
                             skip = False
                             break
-                    ip_adapter_attention_weights.ip_adapter_weights = [ip_adapter_weight]
+                    ip_adapter_attention_weights.ip_adapter_weights = ip_adapter_weight
                     ip_adapter_attention_weights.skip = skip
+                    total_ip_adapter_attention_weights.append(ip_adapter_attention_weights)
 
                 # Collect the weights from each IP Adapter for the idx'th attention processor.
 
-                attn_procs[name] = CustomAttnProcessor2_0(ip_adapter_attention_weights)
+                attn_procs[name] = CustomAttnProcessor2_0(total_ip_adapter_attention_weights)
 
         return attn_procs
 
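The rest of the loop is unchanged: an adapter is applied to a given attention layer only if one of its target blocks appears in the processor name, otherwise its skip flag is set. A standalone illustration of that matching, with made-up processor names (only the logic mirrors the loop above):

def should_skip(processor_name: str, target_blocks: list[str]) -> bool:
    # skip stays True unless some target block substring matches the processor name
    for block in target_blocks:
        if block in processor_name:
            return False
    return True

print(should_skip("down_blocks.2.attentions.1.transformer_blocks.0.attn2.processor", ["down_blocks.2"]))  # False
print(should_skip("up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor", ["down_blocks.2"]))    # True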