sliced swap working

Damian Stewart 2023-01-25 21:38:27 +01:00
parent c52dd7e3f4
commit 1f5ad1b05e

@@ -346,7 +346,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
         old_attn_processors = unet.attn_processors
         # try to re-use an existing slice size
         default_slice_size = 4
-        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
         unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
         return old_attn_processors
     else:
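
Note on the hunk above: the old generator expression yielded the matching SlicedAttnProcessor object itself, so slice_size could end up holding a processor instead of an int; the new expression yields the processor's slice_size and falls back to the default. A minimal sketch of the difference, using a stand-in class rather than the real diffusers SlicedAttnProcessor:

# Stand-in for diffusers' SlicedAttnProcessor, for illustration only.
class SlicedAttnProcessor:
    def __init__(self, slice_size: int):
        self.slice_size = slice_size

old_attn_processors = {"down_blocks.0.attentions.0": SlicedAttnProcessor(slice_size=2)}
default_slice_size = 4

# old expression: yields the processor object, not its slice size
found = next((p for p in old_attn_processors.values()
              if type(p) is SlicedAttnProcessor), default_slice_size)
print(type(found).__name__)   # SlicedAttnProcessor

# new expression: yields the integer slice size, falling back to the default
slice_size = next((p.slice_size for p in old_attn_processors.values()
                   if type(p) is SlicedAttnProcessor), default_slice_size)
print(slice_size)             # 2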
@@ -654,22 +654,23 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
         query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
         original_text_embeddings = encoder_hidden_states
-        original_text_key = attn.to_k(original_text_embeddings)
-        original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        original_text_key = attn.to_k(original_text_embeddings)
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.to_v(original_text_embeddings)
+        modified_value = attn.to_v(modified_text_embeddings)
 
-        # for the "value" just use the modified text embeddings.
-        value = attn.to_v(modified_text_embeddings)
-        value = attn.head_to_batch_dim(value)
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.head_to_batch_dim(original_value)
+        modified_value = attn.head_to_batch_dim(modified_value)
 
         # compute slices and prepare output tensor
         batch_size_attention = query.shape[0]
-        dim = query.shape[-1]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
         )
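
Note on the hunk above: a shape sketch (toy dimensions, not InvokeAI code) of why dim needs to be read from the query before head_to_batch_dim. The reshape folds the heads into the batch dimension, so reading query.shape[-1] afterwards would give the per-head size and dim // attn.heads would no longer match the per-head width of the output buffer. The head_to_batch_dim below is a hand-rolled equivalent of the reshape diffusers' CrossAttention performs:

import torch

batch, seq_len, heads, head_dim = 2, 77, 8, 40
inner_dim = heads * head_dim

def head_to_batch_dim(t: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq, heads*head_dim) -> (batch*heads, seq, head_dim)
    b, s, d = t.shape
    return t.reshape(b, s, heads, d // heads).permute(0, 2, 1, 3).reshape(b * heads, s, d // heads)

query = torch.randn(batch, seq_len, inner_dim)
dim = query.shape[-1]                    # inner_dim, captured before the reshape
query = head_to_batch_dim(query, heads)  # (batch*heads, seq_len, head_dim)

batch_size_attention = query.shape[0]
hidden_states = torch.zeros(batch_size_attention, seq_len, dim // heads)
assert hidden_states.shape[-1] == query.shape[-1] == head_dim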
@@ -677,36 +678,31 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # do slices
         for i in range(hidden_states.shape[0] // self.slice_size):
             start_idx = i * self.slice_size
-            end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size)
+            end_idx = (i + 1) * self.slice_size
             query_slice = query[start_idx:end_idx]
-            attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+            original_key_slice = original_text_key[start_idx:end_idx]
+            modified_key_slice = modified_text_key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            # first, find attention probabilities for the "original" prompt
-            original_text_key_slice = original_text_key[start_idx:end_idx]
-            original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice)
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_key_slice = modified_text_key[start_idx:end_idx]
-            modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice)
+            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
 
             # because the prompt modifications may result in token sequences shifted forwards or backwards,
             # the original attention probabilities must be remapped to account for token index changes in the
             # modified prompt
-            remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1,
-                                                                         swap_cross_attn_context.index_map)
+            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
+                                                              swap_cross_attn_context.index_map)
 
             # only some tokens taken from the original attention probabilities. this is controlled by the mask.
             mask = swap_cross_attn_context.mask
             inverse_mask = 1 - mask
-            attention_probs_slice = \
-                remapped_original_attention_probs_slice * mask + \
-                modified_attention_probs_slice * inverse_mask
+            attn_slice = \
+                remapped_original_attn_slice * mask + \
+                modified_attn_slice * inverse_mask
 
-            value_slice = value[start_idx:end_idx]
-            hidden_states_slice = torch.bmm(attention_probs_slice, value_slice)
-            hidden_states[start_idx:end_idx] = hidden_states_slice
+            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
+            hidden_states[start_idx:end_idx] = attn_slice
 
         # done
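
Note on the hunk above: a standalone sketch of the per-slice swap itself (toy tensors, hypothetical index_map and mask values, not InvokeAI code). Attention computed against the original prompt's keys is remapped to the modified prompt's token positions with torch.index_select, blended with the modified prompt's attention under a per-token mask, and the result is applied to the modified prompt's values:

import torch

heads_x_batch, query_len, num_tokens, head_dim = 4, 16, 6, 8

# per-slice attention probabilities over original and modified prompt tokens
original_attn = torch.softmax(torch.randn(heads_x_batch, query_len, num_tokens), dim=-1)
modified_attn = torch.softmax(torch.randn(heads_x_batch, query_len, num_tokens), dim=-1)
modified_value = torch.randn(heads_x_batch, num_tokens, head_dim)

# index_map[j]: position in the original prompt of the token now at position j (hypothetical values)
index_map = torch.tensor([0, 1, 3, 4, 2, 5])
# mask[j] == 1 where the original prompt's attention should be kept for token j (hypothetical values)
mask = torch.tensor([1., 1., 0., 1., 1., 1.])
inverse_mask = 1 - mask

remapped_original_attn = torch.index_select(original_attn, -1, index_map)
attn = remapped_original_attn * mask + modified_attn * inverse_mask

hidden_states = torch.bmm(attn, modified_value)  # (heads_x_batch, query_len, head_dim)
print(hidden_states.shape)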