diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 99ef1d49bc..770c71f110 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -636,7 +636,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor): del remapped_original_attn_slice, modified_attn_slice - attn_slice = torch.bmm(attn_slice, original_value[start_idx:end_idx]) + attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx]) hidden_states[start_idx:end_idx] = attn_slice