wip updating docs

This commit is contained in:
Damian Stewart
2023-01-25 23:49:38 +01:00
parent 93a24445dc
commit 5e7ed964d2
2 changed files with 21 additions and 25 deletions

View File

@ -594,12 +594,12 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
original_text_key = attn.to_k(original_text_embeddings)
modified_text_key = attn.to_k(modified_text_embeddings)
#original_value = attn.to_v(original_text_embeddings)
original_value = attn.to_v(original_text_embeddings)
modified_value = attn.to_v(modified_text_embeddings)
original_text_key = attn.head_to_batch_dim(original_text_key)
modified_text_key = attn.head_to_batch_dim(modified_text_key)
#original_value = attn.head_to_batch_dim(original_value)
original_value = attn.head_to_batch_dim(original_value)
modified_value = attn.head_to_batch_dim(modified_value)
# compute slices and prepare output tensor
@ -636,7 +636,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
del remapped_original_attn_slice, modified_attn_slice
attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
attn_slice = torch.bmm(attn_slice, original_value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice