Mirror of https://github.com/invoke-ai/InvokeAI
commit 63c6019f92 (parent c0610f7cb9)

    sliced attention processor wip (untested)
@@ -514,7 +514,7 @@ from dataclasses import field, dataclass
 
 import torch
 
-from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor
+from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor
 from ldm.models.diffusion.cross_attention_control import CrossAttentionType
 
 
@@ -625,3 +625,95 @@ class SwapCrossAttnProcessor(CrossAttnProcessor):
 
         return hidden_states
 
+
+class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
+
+    def __init__(self, slice_size=int(1e6)):
+        self.slice_size = slice_size
+        # TODO: dynamically pick slice size based on memory conditions
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
+                 # kwargs
+                 swap_cross_attn_context: SwapCrossAttnContext=None):
+
+        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
+
+        # if cross-attention control is not in play, just call through to the base implementation.
+        if attention_type is CrossAttentionType.SELF or \
+                swap_cross_attn_context is None or \
+                not swap_cross_attn_context.wants_cross_attention_control(attention_type):
+            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
+            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
+        #else:
+        #    print(f"SwapCrossAttnContext for {attention_type} active")
+
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states)
+        query = attn.head_to_batch_dim(query)
+
+        original_text_embeddings = encoder_hidden_states
+        original_text_key = attn.to_k(original_text_embeddings)
+        original_text_key = attn.head_to_batch_dim(original_text_key)
+        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
+        modified_text_key = attn.to_k(modified_text_embeddings)
+        modified_text_key = attn.head_to_batch_dim(modified_text_key)
+
+        # for the "value" just use the modified text embeddings.
+        value = attn.to_v(modified_text_embeddings)
+        value = attn.head_to_batch_dim(value)
+
+        # compute slices and prepare output tensor
+        batch_size_attention = query.shape[0]
+        dim = query.shape[-1]
+        hidden_states = torch.zeros(
+            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
+        )
+
+        # do slices
+        for i in range((hidden_states.shape[0] + self.slice_size - 1) // self.slice_size):
+            start_idx = i * self.slice_size
+            end_idx = min(hidden_states.shape[0], (i + 1) * self.slice_size)
+
+            query_slice = query[start_idx:end_idx]
+            attention_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
+
+            # first, find attention probabilities for the "original" prompt
+            original_text_key_slice = original_text_key[start_idx:end_idx]
+            original_attention_probs_slice = attn.get_attention_scores(query_slice, original_text_key_slice, attention_mask_slice)
+
+            # then, find attention probabilities for the "modified" prompt
+            modified_text_key_slice = modified_text_key[start_idx:end_idx]
+            modified_attention_probs_slice = attn.get_attention_scores(query_slice, modified_text_key_slice, attention_mask_slice)
+
+            # because the prompt modifications may result in token sequences shifted forwards or backwards,
+            # the original attention probabilities must be remapped to account for token index changes in the
+            # modified prompt
+            remapped_original_attention_probs_slice = torch.index_select(original_attention_probs_slice, -1,
+                                                                         swap_cross_attn_context.index_map)
+
+            # only some tokens taken from the original attention probabilities. this is controlled by the mask.
+            mask = swap_cross_attn_context.mask
+            inverse_mask = 1 - mask
+            attention_probs_slice = \
+                remapped_original_attention_probs_slice * mask + \
+                modified_attention_probs_slice * inverse_mask
+
+            value_slice = value[start_idx:end_idx]
+            hidden_states_slice = torch.bmm(attention_probs_slice, value_slice)
+
+            hidden_states[start_idx:end_idx] = hidden_states_slice
+
+        # done
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
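
Note (not part of the commit): a minimal, standalone sketch of the remap-and-blend step above, using only plain torch and made-up shapes. The index map says, for each token position in the modified prompt, which original-prompt token position should supply its attention; the mask chooses per token whether to keep that remapped original attention or the attention computed against the modified prompt.

import torch

# hypothetical sizes: (batch * heads, query_len, key_len)
heads_times_batch, query_len, key_len = 2, 4, 6
original_probs = torch.rand(heads_times_batch, query_len, key_len)
modified_probs = torch.rand(heads_times_batch, query_len, key_len)

# index_map[j]: which original-prompt token position feeds modified-prompt position j
index_map = torch.tensor([0, 1, 1, 2, 3, 5])
# mask[j] == 1.0 -> keep the remapped original attention for token j; 0.0 -> use the modified attention
mask = torch.tensor([1.0, 1.0, 0.0, 1.0, 1.0, 0.0])

remapped_original = torch.index_select(original_probs, -1, index_map)
blended = remapped_original * mask + modified_probs * (1 - mask)
print(blended.shape)  # torch.Size([2, 4, 6])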
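
Note (not part of the commit): a quick sketch of the slice bookkeeping in the loop above, with made-up sizes. Ceiling division yields at least one slice even when slice_size exceeds the batch*heads dimension (as with the large default), and the min() clamp keeps the final partial slice in range.

batch_times_heads = 10   # e.g. batch_size * attn.heads
slice_size = 4

num_slices = (batch_times_heads + slice_size - 1) // slice_size  # ceiling division -> 3
for i in range(num_slices):
    start_idx = i * slice_size
    end_idx = min(batch_times_heads, (i + 1) * slice_size)
    print(start_idx, end_idx)  # (0, 4), (4, 8), (8, 10)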