fix(experimental): Possible fix for conflict with regional embed length mismatch

Pushing this so people can test it out and see if this needs to be handled in a different way.
This commit is contained in:
blessedcoolant 2024-04-14 12:19:19 +05:30
parent 9cb0f63c44
commit 8426f1e7b2

View File

@ -1,3 +1,4 @@
from itertools import cycle, islice
from typing import List, Optional, TypedDict, cast from typing import List, Optional, TypedDict, cast
import torch import torch
@ -137,12 +138,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
if self._ip_adapter_attention_weights: if self._ip_adapter_attention_weights:
assert regional_ip_data is not None assert regional_ip_data is not None
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len) ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
# Pad weight tensor list to match size of regional embeds
self._ip_adapter_attention_weights["ip_adapter_weights"] = list(
islice(
cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]),
len(regional_ip_data.image_prompt_embeds),
)
)
assert ( assert (
len(regional_ip_data.image_prompt_embeds) len(regional_ip_data.image_prompt_embeds)
== len(self._ip_adapter_attention_weights["ip_adapter_weights"]) == len(self._ip_adapter_attention_weights["ip_adapter_weights"])
== len(regional_ip_data.scales) == len(regional_ip_data.scales)
== ip_masks.shape[1] == ip_masks.shape[1]
) )
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds): for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index] ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
ipa_scale = regional_ip_data.scales[ipa_index] ipa_scale = regional_ip_data.scales[ipa_index]