diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 1a161fbc86..9c8c597869 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -68,6 +68,8 @@ class CrossAttentionControl: indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 + cls.inject_attention_function(model) + for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): m.last_attn_slice_mask = None m.last_attn_slice_indices = None @@ -76,8 +78,6 @@ class CrossAttentionControl: m.last_attn_slice_mask = mask.to(device) m.last_attn_slice_indices = indices.to(device) - cls.inject_attention_function(model) - class CrossAttentionType(Enum): SELF = 1