# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch

def prompt_token(prompt, index, clip_tokenizer):
    # Tokenize the prompt and decode the single token at `index`, so callers can see
    # which word a given token position corresponds to.
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])

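# Minimal usage sketch (hypothetical prompt; `clip_tokenizer` is assumed to be the
# pipeline's CLIPTokenizer). Index 0 is the BOS token, so the first prompt word
# usually sits at index 1:
#
#   prompt_token('a forest at dawn', 2, clip_tokenizer)  # roughly 'forest'
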
def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
    # Build a per-token weight vector (default 1.0) and attach it to every
    # cross-attention ('attn2') module; self-attention ('attn1') is left unweighted.
    tokens_length = clip_tokenizer.model_max_length
    weights = torch.ones(tokens_length)

    for i, w in weight_tuples:
        if 0 <= i < tokens_length:
            weights[i] = w

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_weights = weights.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_weights = None

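# Minimal usage sketch (token indices and weights are assumed values): emphasize the
# token at position 2 and de-emphasize the one at position 5, then arm the weights
# for the next UNet pass:
#
#   init_attention_weights([(2, 1.5), (5, 0.5)], clip_tokenizer, unet, device)
#   use_last_tokens_attention_weights(unet)
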
def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
    # Diff the original and edited token sequences, then record on every
    # cross-attention module which target positions may reuse the saved attention
    # (mask == 1) and which source positions they map back to (indices).
    tokens_length = clip_tokenizer.model_max_length
    mask = torch.zeros(tokens_length)
    indices_target = torch.arange(tokens_length, dtype=torch.long)
    indices = torch.zeros(tokens_length, dtype=torch.long)

    tokens = tokens.input_ids.numpy()[0]
    tokens_edit = tokens_edit.input_ids.numpy()[0]

    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
        if b0 < tokens_length:
            if name == 'equal' or (name == 'replace' and a1 - a0 == b1 - b0):
                mask[b0:b1] = 1
                indices[b0:b1] = indices_target[a0:a1]

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_mask = mask.to(device)
            module.last_attn_slice_indices = indices.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_mask = None
            module.last_attn_slice_indices = None

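# Minimal usage sketch (hypothetical prompts; the tokenizer arguments mirror the ones
# used in prompt_token() above):
#
#   tokens = clip_tokenizer('a photo of a cat', padding='max_length',
#                           max_length=clip_tokenizer.model_max_length,
#                           truncation=True, return_tensors='pt')
#   tokens_edit = clip_tokenizer('a photo of a dog', padding='max_length',
#                                max_length=clip_tokenizer.model_max_length,
#                                truncation=True, return_tensors='pt')
#   init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device)
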
def init_attention_func(unet):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
    def new_attention(self, query, key, value):
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attention_scores.softmax(dim=-1)
        # compute attention output

        if self.use_last_attn_slice:
            if self.last_attn_slice_mask is not None:
                # Splice the saved attention back in at the positions selected by the
                # mask/indices prepared in init_attention_edit()
                new_attn_slice = torch.index_select(self.last_attn_slice, -1,
                                                    self.last_attn_slice_indices)
                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                              + new_attn_slice * self.last_attn_slice_mask)
            else:
                attn_slice = self.last_attn_slice

            self.use_last_attn_slice = False

        if self.save_last_attn_slice:
            self.last_attn_slice = attn_slice
            self.save_last_attn_slice = False

        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
            attn_slice = attn_slice * self.last_attn_slice_weights
            self.use_last_attn_weights = False

        hidden_states = torch.matmul(attn_slice, value)
        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    def new_sliced_attention(self, query, key, value, sequence_length, dim):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads),
            device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size
            attn_slice = (
                torch.matmul(query[start_idx:end_idx],
                             key[start_idx:end_idx].transpose(1, 2)) * self.scale
            )  # TODO: use baddbmm for better performance
            attn_slice = attn_slice.softmax(dim=-1)

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    new_attn_slice = torch.index_select(self.last_attn_slice,
                                                        -1, self.last_attn_slice_indices)
                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                                  + new_attn_slice * self.last_attn_slice_mask)
                else:
                    attn_slice = self.last_attn_slice

                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        return self.reshape_batch_dim_to_heads(hidden_states)  # reshape hidden_states

    # Bind the patched attention implementations onto every CrossAttention module
    # and initialize the control flags they consult.
    for _, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention':
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
            module._attention = new_attention.__get__(module, type(module))

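# Minimal usage sketch: init_attention_func(unet) has to run once, after the UNet is
# loaded, before the save_/use_ toggles below have any effect, since it is what makes
# the CrossAttention modules consult the control flags:
#
#   init_attention_func(unet)            # patch every CrossAttention module
#   save_last_tokens_attention(unet)     # record attention on the next forward pass
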
def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save
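
# Sketch of the overall prompt-to-prompt flow these helpers are meant for, assuming a
# standard diffusers denoising loop (UNet/scheduler calls are elided; only the helper
# names defined above are real API). Inside each timestep:
#
#   save_last_tokens_attention(unet)
#   save_last_self_attention(unet)
#   # ... UNet forward pass with the original prompt's conditioning ...
#
#   use_last_tokens_attention(unet)
#   use_last_self_attention(unet)
#   # ... UNet forward pass with the edited prompt's conditioning, which now reuses
#   # the attention maps saved on the previous pass ...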