https://github.com/invoke-ai/InvokeAI

sliced swap working
commit 1f5ad1b05e (parent c52dd7e3f4)
@@ -346,7 +346,7 @@ def setup_cross_attention_control(model, context: Context, is_running_diffusers
         old_attn_processors = unet.attn_processors
         # try to re-use an existing slice size
         default_slice_size = 4
-        slice_size = next((p for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
+        slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
         unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
         return old_attn_processors
     else:
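Note on the hunk above: the old generator expression yielded the SlicedAttnProcessor objects themselves, so slice_size could end up bound to a processor instance rather than an int; the fix reads p.slice_size from the first sliced processor found. A minimal self-contained sketch of the probe logic (the stand-in class and example dict are illustrative only, not InvokeAI code):

# Stand-in playing the role of diffusers' SlicedAttnProcessor so the probe runs on its own.
class SlicedAttnProcessor:
    def __init__(self, slice_size):
        self.slice_size = slice_size

default_slice_size = 4
# Hypothetical contents standing in for unet.attn_processors (name -> processor instance).
old_attn_processors = {"down.attn1": SlicedAttnProcessor(slice_size=2)}

# Take the slice size of the first sliced processor found, else fall back to the default.
slice_size = next(
    (p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor),
    default_slice_size,
)
assert slice_size == 2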
@@ -654,22 +654,23 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
 
         query = attn.to_q(hidden_states)
+        dim = query.shape[-1]
         query = attn.head_to_batch_dim(query)
 
         original_text_embeddings = encoder_hidden_states
-        original_text_key = attn.to_k(original_text_embeddings)
-        original_text_key = attn.head_to_batch_dim(original_text_key)
         modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        original_text_key = attn.to_k(original_text_embeddings)
         modified_text_key = attn.to_k(modified_text_embeddings)
-        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.to_v(original_text_embeddings)
+        modified_value = attn.to_v(modified_text_embeddings)
 
-        # for the "value" just use the modified text embeddings.
-        value = attn.to_v(modified_text_embeddings)
-        value = attn.head_to_batch_dim(value)
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+        #original_value = attn.head_to_batch_dim(original_value)
+        modified_value = attn.head_to_batch_dim(modified_value)
 
         # compute slices and prepare output tensor
         batch_size_attention = query.shape[0]
-        dim = query.shape[-1]
         hidden_states = torch.zeros(
             (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
         )
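Note on the hunk above: dim = query.shape[-1] now runs before head_to_batch_dim, presumably because that call splits the channel dimension across heads, so reading shape[-1] afterwards would already give dim // attn.heads and the later dim // attn.heads allocation would divide twice. The key/value projections are also gathered up front, and only the modified prompt's values are projected (the original-value lines stay commented out). A small stand-in sketch of the shape effect, assuming the usual (batch, seq, dim) to (batch*heads, seq, dim/heads) reshape; this is not the diffusers implementation:

import torch

batch, seq, dim, heads = 2, 77, 320, 8

def head_to_batch_dim(x: torch.Tensor, heads: int) -> torch.Tensor:
    # Stand-in reshape: split channels across heads, fold heads into the batch dimension.
    b, s, d = x.shape
    x = x.reshape(b, s, heads, d // heads).permute(0, 2, 1, 3)
    return x.reshape(b * heads, s, d // heads)

query = torch.randn(batch, seq, dim)
dim_before = query.shape[-1]            # 320, what the zeros() allocation needs
query = head_to_batch_dim(query, heads)
print(query.shape)                      # torch.Size([16, 77, 40]) == (batch*heads, seq, dim//heads)
print(dim_before // heads)              # 40, matches the per-head channel width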
@@ -677,36 +678,31 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         # do slices
         for i in range(hidden_states.shape[0] // self.slice_size):
             start_idx = i * self.slice_size
-            end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size)
+            end_idx = (i + 1) * self.slice_size
 
             query_slice = query[start_idx:end_idx]
-            attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+            original_key_slice = original_text_key[start_idx:end_idx]
+            modified_key_slice = modified_text_key[start_idx:end_idx]
+            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
 
-            # first, find attention probabilities for the "original" prompt
-            original_text_key_slice = original_text_key[start_idx:end_idx]
-            original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice)
-
-            # then, find attention probabilities for the "modified" prompt
-            modified_text_key_slice = modified_text_key[start_idx:end_idx]
-            modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice)
+            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
+            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)
 
             # because the prompt modifications may result in token sequences shifted forwards or backwards,
             # the original attention probabilities must be remapped to account for token index changes in the
             # modified prompt
-            remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1,
+            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
                                                               swap_cross_attn_context.index_map)
 
             # only some tokens taken from the original attention probabilities. this is controlled by the mask.
             mask = swap_cross_attn_context.mask
             inverse_mask = 1 - mask
-            attention_probs_slice = \
-                remapped_original_attention_probs_slice * mask + \
-                modified_attention_probs_slice * inverse_mask
+            attn_slice = \
+                remapped_original_attn_slice * mask + \
+                modified_attn_slice * inverse_mask
 
-            value_slice = value[start_idx:end_idx]
-            hidden_states_slice = torch.bmm(attention_probs_slice, value_slice)
-
-            hidden_states[start_idx:end_idx] = hidden_states_slice
+            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])
+            hidden_states[start_idx:end_idx] = attn_slice
 
 
         # done
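Taken together, the rewritten slice loop scores the query against both prompts' keys, remaps the original-prompt probabilities through swap_cross_attn_context.index_map, blends the two maps with the per-token mask, and projects through the modified prompt's values only, one batch-of-heads slice at a time. A self-contained sketch of that arithmetic with assumed shapes (plain scaled softmax attention stands in for attn.get_attention_scores; the identity index map and the mask values are hypothetical, not taken from InvokeAI):

import torch

batch_heads, seq_len, tokens, head_dim = 16, 4096, 77, 40
slice_size = 4

query = torch.randn(batch_heads, seq_len, head_dim)
original_key = torch.randn(batch_heads, tokens, head_dim)
modified_key = torch.randn(batch_heads, tokens, head_dim)
modified_value = torch.randn(batch_heads, tokens, head_dim)

# Per-token controls: index_map realigns original-prompt tokens to their positions in
# the modified prompt, mask selects which tokens keep the original attention.
index_map = torch.arange(tokens)          # identity remap, for the sketch only
mask = torch.zeros(tokens)
mask[5:10] = 1.0                          # hypothetical: keep original attention for tokens 5..9

hidden_states = torch.zeros(batch_heads, seq_len, head_dim)
scale = head_dim ** -0.5

for i in range(batch_heads // slice_size):
    s, e = i * slice_size, (i + 1) * slice_size

    # Attention probabilities against the original and the modified prompt keys.
    original_attn = torch.softmax(torch.bmm(query[s:e], original_key[s:e].transpose(1, 2)) * scale, dim=-1)
    modified_attn = torch.softmax(torch.bmm(query[s:e], modified_key[s:e].transpose(1, 2)) * scale, dim=-1)

    # Remap original probabilities to the modified prompt's token positions, then blend by the mask.
    remapped_original = torch.index_select(original_attn, -1, index_map)
    attn = remapped_original * mask + modified_attn * (1 - mask)

    # The values always come from the modified prompt.
    hidden_states[s:e] = torch.bmm(attn, modified_value[s:e])

print(hidden_states.shape)  # torch.Size([16, 4096, 40])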