Mirror of https://github.com/invoke-ai/InvokeAI
sliced swap working
parent c52dd7e3f4
commit 1f5ad1b05e
@@ -346,7 +346,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
         old_attn_processors = unet.attn_processors
         # try to re-use an existing slice size
         default_slice_size = 4
-        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
         unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
         return old_attn_processors
     else:
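The one-line change above fixes the slice-size reuse: the old generator expression yielded the SlicedAttnProcessor object itself, so slice_size ended up holding a processor instead of an int. A minimal, self-contained sketch of the difference; the stand-in class and dict below are illustrative only, not the real diffusers objects:

```python
# Stand-in for diffusers' SlicedAttnProcessor, just to make the lookup runnable here.
class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

# Stand-in for unet.attn_processors: a module-name -> processor mapping.
old_attn_processors = {
    "down_blocks.0.attentions.0.processor": SlicedAttnProcessor(slice_size=2),
}
default_slice_size = 4

# Buggy form (removed line): yields the processor object itself, not an int.
bad = next((p for p in old_attn_processors.values()
            if type(p) is SlicedAttnProcessor), default_slice_size)

# Fixed form (added line): yields that processor's slice_size, or the default if none is installed.
slice_size = next((p.slice_size for p in old_attn_processors.values()
                   if type(p) is SlicedAttnProcessor), default_slice_size)

assert isinstance(bad, SlicedAttnProcessor)
assert slice_size == 2
```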
@@ -654,22 +654,23 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

         query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)

         original_text_embeddings = encoder_hidden_states
-        original_text_key = attn.to_k(original_text_embeddings)
-        original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        original_text_key = attn.to_k(original_text_embeddings)
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.to_v(original_text_embeddings)
+        modified_value = attn.to_v(modified_text_embeddings)

-        # for the "value" just use the modified text embeddings.
-        value = attn.to_v(modified_text_embeddings)
-        value = attn.head_to_batch_dim(value)
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.head_to_batch_dim(original_value)
+        modified_value = attn.head_to_batch_dim(modified_value)

         # compute slices and prepare output tensor
         batch_size_attention = query.shape[0]
-        dim = query.shape[-1]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
         )
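Everything above is projected and then flattened with attn.head_to_batch_dim, so the slicing loop in the next hunk iterates over the combined batch*heads dimension in chunks of slice_size. A rough, self-contained sketch of that reshape; it mirrors what diffusers' head_to_batch_dim does, but is written from scratch here purely as an illustration:

```python
import torch

def head_to_batch_dim(x: torch.Tensor, heads: int) -> torch.Tensor:
    # (batch, seq_len, heads * dim_head) -> (batch * heads, seq_len, dim_head)
    batch, seq_len, inner = x.shape
    dim_head = inner // heads
    x = x.reshape(batch, seq_len, heads, dim_head)
    x = x.permute(0, 2, 1, 3)
    return x.reshape(batch * heads, seq_len, dim_head)

key = torch.randn(2, 77, 8 * 40)                 # batch=2, 77 text tokens, 8 heads of 40 dims
print(head_to_batch_dim(key, heads=8).shape)     # torch.Size([16, 77, 40])
```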
@@ -677,36 +678,31 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # do slices
         for i in range(hidden_states.shape[0] // self.slice_size):
             start_idx = i * self.slice_size
-            end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size)
+            end_idx = (i + 1) * self.slice_size

             query_slice = query[start_idx:end_idx]
-            attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+            original_key_slice = original_text_key[start_idx:end_idx]
+            modified_key_slice = modified_text_key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

-            # first, find attention probabilities for the "original" prompt
-            original_text_key_slice = original_text_key[start_idx:end_idx]
-            original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_key_slice = modified_text_key[start_idx:end_idx]
-            modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice)
+            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)

             # because the prompt modifications may result in token sequences shifted forwards or backwards,
             # the original attention probabilities must be remapped to account for token index changes in the
             # modified prompt
-            remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1,
-                                                                         swap_cross_attn_context.index_map)
+            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
+                                                              swap_cross_attn_context.index_map)

             # only some tokens taken from the original attention probabilities. this is controlled by the mask.
             mask = swap_cross_attn_context.mask
             inverse_mask = 1 - mask
-            attention_probs_slice = \
-                remapped_original_attention_probs_slice * mask + \
-                modified_attention_probs_slice * inverse_mask
+            attn_slice = \
+                remapped_original_attn_slice * mask + \
+                modified_attn_slice * inverse_mask

-            value_slice = value[start_idx:end_idx]
-            hidden_states_slice = torch.bmm(attention_probs_slice, value_slice)
-
-            hidden_states[start_idx:end_idx] = hidden_states_slice
+            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
+            hidden_states[start_idx:end_idx] = attn_slice


         # done
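The rewritten loop body is the core of the prompt-to-prompt swap: attention scores are computed against both the original and the modified prompt keys, the original scores are reordered into the modified prompt's token positions via index_map, and the two score sets are blended per token with mask before being multiplied into the modified values. A standalone sketch of that blend, with made-up shapes and an identity index_map purely for illustration:

```python
import torch

batch_x_heads, query_len, token_len, dim_head = 16, 4096, 77, 40

original_attn = torch.softmax(torch.randn(batch_x_heads, query_len, token_len), dim=-1)
modified_attn = torch.softmax(torch.randn(batch_x_heads, query_len, token_len), dim=-1)

# index_map[i]: which original-prompt token position corresponds to token i of the modified prompt
index_map = torch.arange(token_len)      # identity mapping, i.e. no tokens shifted
# mask[i] = 1.0 -> keep the (remapped) original attention for token i, 0.0 -> use the modified attention
mask = torch.ones(token_len)
mask[5:10] = 0.0                         # pretend tokens 5..9 were edited in the modified prompt

remapped_original_attn = torch.index_select(original_attn, -1, index_map)
blended_attn = remapped_original_attn * mask + modified_attn * (1 - mask)

modified_value = torch.randn(batch_x_heads, token_len, dim_head)
out = torch.bmm(blended_attn, modified_value)     # -> (16, 4096, 40)
print(out.shape)
```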