mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(experimental): Possible fix for conflict with regional embed length mismatch
Pushing this so people can test it out and see if this needs to be handled in a different way.
This commit is contained in:
parent
9cb0f63c44
commit
8426f1e7b2
@ -1,3 +1,4 @@
|
||||
from itertools import cycle, islice
|
||||
from typing import List, Optional, TypedDict, cast
|
||||
|
||||
import torch
|
||||
@ -137,12 +138,22 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
||||
if self._ip_adapter_attention_weights:
|
||||
assert regional_ip_data is not None
|
||||
ip_masks = regional_ip_data.get_masks(query_seq_len=query_seq_len)
|
||||
|
||||
# Pad weight tensor list to match size of regional embeds
|
||||
self._ip_adapter_attention_weights["ip_adapter_weights"] = list(
|
||||
islice(
|
||||
cycle(self._ip_adapter_attention_weights["ip_adapter_weights"]),
|
||||
len(regional_ip_data.image_prompt_embeds),
|
||||
)
|
||||
)
|
||||
|
||||
assert (
|
||||
len(regional_ip_data.image_prompt_embeds)
|
||||
== len(self._ip_adapter_attention_weights["ip_adapter_weights"])
|
||||
== len(regional_ip_data.scales)
|
||||
== ip_masks.shape[1]
|
||||
)
|
||||
|
||||
for ipa_index, ipa_embed in enumerate(regional_ip_data.image_prompt_embeds):
|
||||
ipa_weights = self._ip_adapter_attention_weights["ip_adapter_weights"][ipa_index]
|
||||
ipa_scale = regional_ip_data.scales[ipa_index]
|
||||
|
Loading…
Reference in New Issue
Block a user