mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(experimental): Possible fix for conflict with regional embed length mismatch
Pushing this so people can test it out and see if this needs to be handled in a different way.
This commit is contained in:
parent
9cb0f63c44
commit
8426f1e7b2
@ -1,3 +1,4 @@
|
|||||||
|
from itertools import cycle, islice
|
||||||
from typing import List, Optional, TypedDict, cast
|
from typing import List, Optional, TypedDict, cast
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -137,12 +138,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
if self._ip_adapter_attention_weights:
|
if self._ip_adapter_attention_weights:
|
||||||
assert regional_ip_data is not None
|
assert regional_ip_data is not None
|
||||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
|
||||||
|
# Pad weight tensor list to match size of regional embeds
|
||||||
|
self._ip_adapter_attention_weights["ip_adapter_weights"] = list(
|
||||||
|
islice(
|
||||||
|
cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]),
|
||||||
|
len(regional_ip_data.image_prompt_embeds),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
len(regional_ip_data.image_prompt_embeds)
|
len(regional_ip_data.image_prompt_embeds)
|
||||||
== len(self._ip_adapter_attention_weights["ip_adapter_weights"])
|
== len(self._ip_adapter_attention_weights["ip_adapter_weights"])
|
||||||
== len(regional_ip_data.scales)
|
== len(regional_ip_data.scales)
|
||||||
== ip_masks.shape[1]
|
== ip_masks.shape[1]
|
||||||
)
|
)
|
||||||
|
|
||||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||||
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
|
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
|
||||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||||
|
Loading…
Reference in New Issue
Block a user