From 688d7258f185035a5924bee23322b001574da931 Mon Sep 17 00:00:00 2001 From: damian0815 Date: Wed, 2 Nov 2022 00:33:00 +0100 Subject: [PATCH] fix a bug that broke cross attention control index mapping --- ldm/models/diffusion/cross_attention_control.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 1a161fbc86..9c8c597869 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -68,6 +68,8 @@ class CrossAttentionControl: indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 + cls.inject_attention_function(model) + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): m.last_attn_slice_mask = None m.last_attn_slice_indices = None @@ -76,8 +78,6 @@ class CrossAttentionControl: m.last_attn_slice_mask = mask.to(device) m.last_attn_slice_indices = indices.to(device) - cls.inject_attention_function(model) - class CrossAttentionType(Enum): SELF = 1