diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py
index 08c62060c9..8b5467e85f 100644
--- a/ldm/models/diffusion/cross_attention_control.py
+++ b/ldm/models/diffusion/cross_attention_control.py
@@ -346,7 +346,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
         old_attn_processors = unet.attn_processors
         # try to re-use an existing slice size
         default_slice_size = 4
-        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
         unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
         return old_attn_processors
     else:
@@ -654,22 +654,23 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
         original_text_embeddings = encoder_hidden_states
-        original_text_key = attn.to_k(original_text_embeddings)
-        original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        original_text_key = attn.to_k(original_text_embeddings)
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.to_v(original_text_embeddings)
+        modified_value = attn.to_v(modified_text_embeddings)
 
-        # for the "value" just use the modified text embeddings.
-        value = attn.to_v(modified_text_embeddings)
-        value = attn.head_to_batch_dim(value)
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.head_to_batch_dim(original_value)
+        modified_value = attn.head_to_batch_dim(modified_value)
 
         # compute slices and prepare output tensor
         batch_size_attention = query.shape[0]
-        dim = query.shape[-1]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
         )
@@ -677,36 +678,31 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # do slices
         for i in range(hidden_states.shape[0] // self.slice_size):
             start_idx = i * self.slice_size
-            end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size)
+            end_idx = (i + 1) * self.slice_size
 
             query_slice = query[start_idx:end_idx]
-            attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+            original_key_slice = original_text_key[start_idx:end_idx]
+            modified_key_slice = modified_text_key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            # first, find attention probabilities for the "original" prompt
-            original_text_key_slice = original_text_key[start_idx:end_idx]
-            original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_key_slice = modified_text_key[start_idx:end_idx]
-            modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice)
+            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
 
             # because the prompt modifications may result in token sequences shifted forwards or backwards,
             # the original attention probabilities must be remapped to account for token index changes in the
             # modified prompt
-            remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1,
-                                                                         swap_cross_attn_context.index_map)
+            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
+                                                              swap_cross_attn_context.index_map)
 
             # only some tokens taken from the original attention probabilities. this is controlled by the mask.
             mask = swap_cross_attn_context.mask
             inverse_mask = 1 - mask
-            attention_probs_slice = \
-                remapped_original_attention_probs_slice * mask + \
-                modified_attention_probs_slice * inverse_mask
+            attn_slice = \
+                remapped_original_attn_slice * mask + \
+                modified_attn_slice * inverse_mask
 
-            value_slice = value[start_idx:end_idx]
-            hidden_states_slice = torch.bmm(attention_probs_slice, value_slice)
-
-            hidden_states[start_idx:end_idx] = hidden_states_slice
+            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
+            hidden_states[start_idx:end_idx] = attn_slice
 
 
         # done
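For reference, here is a minimal, self-contained sketch of the per-slice swap-and-blend step that this processor performs: score the query against both the original and the modified prompt keys, remap the original probabilities through the token index map, blend the two attention maps with the per-token mask, and apply the result to the modified values. The function name, the plain bmm-plus-softmax scoring (the real code calls `attn.get_attention_scores`, which also applies the attention mask), and the toy shapes in the smoke test are illustrative assumptions, not the actual `SlicedSwapCrossAttnProcesser` implementation; `index_map` and `mask` stand in for `swap_cross_attn_context.index_map` and `swap_cross_attn_context.mask` from the diff.

```python
import torch

def swap_cross_attention_slice(
    query_slice: torch.Tensor,           # (slice, q_tokens, head_dim)
    original_key_slice: torch.Tensor,    # (slice, text_tokens, head_dim) - keys from the original prompt
    modified_key_slice: torch.Tensor,    # (slice, text_tokens, head_dim) - keys from the modified prompt
    modified_value_slice: torch.Tensor,  # (slice, text_tokens, head_dim) - values from the modified prompt
    index_map: torch.Tensor,             # (text_tokens,) long: modified token index -> original token index
    mask: torch.Tensor,                  # (text_tokens,) float: 1.0 keeps original attention, 0.0 uses modified
    scale: float,
) -> torch.Tensor:
    # attention probabilities against the original and the modified prompt embeddings
    # (attention mask omitted for brevity; the processor passes attn_mask_slice to get_attention_scores)
    original_attn = (torch.bmm(query_slice, original_key_slice.transpose(-1, -2)) * scale).softmax(dim=-1)
    modified_attn = (torch.bmm(query_slice, modified_key_slice.transpose(-1, -2)) * scale).softmax(dim=-1)

    # remap original probabilities so their columns line up with the (possibly shifted) modified tokens
    remapped_original_attn = torch.index_select(original_attn, -1, index_map)

    # per-token blend: masked columns keep the original prompt's attention, the rest use the modified prompt's
    blended_attn = remapped_original_attn * mask + modified_attn * (1 - mask)

    # values always come from the modified prompt
    return torch.bmm(blended_attn, modified_value_slice)


# tiny smoke test with made-up shapes: 2 attention slices, 4 query tokens, 77 text tokens, 64-dim heads
q = torch.randn(2, 4, 64)
k_orig, k_mod, v_mod = torch.randn(2, 77, 64), torch.randn(2, 77, 64), torch.randn(2, 77, 64)
idx = torch.arange(77)            # identity remap for unedited tokens
keep = torch.zeros(77); keep[:10] = 1.0   # keep original attention for the first 10 tokens
out = swap_cross_attention_slice(q, k_orig, k_mod, v_mod, idx, keep, scale=64 ** -0.5)
print(out.shape)  # torch.Size([2, 4, 64])
```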