fix a bug that broke cross attention control index mapping

This commit is contained in:
damian0815 2022-11-02 00:33:00 +01:00 committed by blessedcoolant
parent 4513320bf1
commit 688d7258f1

View File

@ -68,6 +68,8 @@ class CrossAttentionControl:
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
cls.inject_attention_function(model)
for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
@ -76,8 +78,6 @@ class CrossAttentionControl:
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
cls.inject_attention_function(model)
class CrossAttentionType(Enum):
SELF = 1