InvokeAI/c_a_c.py

# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl
from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(
        prompt,
        padding='max_length',
        max_length=clip_tokenizer.model_max_length,
        truncation=True,
        return_tensors='pt',
        return_overflowing_tokens=True,
    ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])


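# Attach per-token attention weights to the UNet's cross-attention modules.
# `weight_tuples` is a sequence of (token_index, weight) pairs; every other
# token keeps a weight of 1.0.  Only prompt cross-attention ('attn2') layers
# receive the weights; self-attention ('attn1') layers are reset to None.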
def init_attention_weights(weight_tuples, clip_tokenizer, unet, device):
    tokens_length = clip_tokenizer.model_max_length
    weights = torch.ones(tokens_length)

    for i, w in weight_tuples:
        if 0 <= i < tokens_length:
            weights[i] = w

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_weights = weights.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_weights = None


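# Build the mask/index tensors used to remap attention from the original
# prompt onto the edited prompt.  SequenceMatcher aligns the two token
# sequences; positions whose tokens are equal (or replaced one-for-one) are
# masked in and pointed back at the matching original-token column.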
def init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device):
    tokens_length = clip_tokenizer.model_max_length
    mask = torch.zeros(tokens_length)
    indices_target = torch.arange(tokens_length, dtype=torch.long)
    indices = torch.zeros(tokens_length, dtype=torch.long)

    tokens = tokens.input_ids.numpy()[0]
    tokens_edit = tokens_edit.input_ids.numpy()[0]

    for name, a0, a1, b0, b1 in SequenceMatcher(None, tokens, tokens_edit).get_opcodes():
        if b0 < tokens_length:
            if name == 'equal' or (name == 'replace' and a1 - a0 == b1 - b0):
                mask[b0:b1] = 1
                indices[b0:b1] = indices_target[a0:a1]

    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.last_attn_slice_mask = mask.to(device)
            module.last_attn_slice_indices = indices.to(device)
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.last_attn_slice_mask = None
            module.last_attn_slice_indices = None


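# Monkey-patch every CrossAttention module in the UNet with attention
# functions that can save their attention maps, re-inject previously saved
# maps, and apply per-token weights, depending on the flags toggled by the
# helpers further below.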
def init_attention_func(unet):
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
    def new_attention(self, query, key, value):
        # TODO: use baddbmm for better performance
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
        attn_slice = attention_scores.softmax(dim=-1)

        # compute attention output
        if self.use_last_attn_slice:
            if self.last_attn_slice_mask is not None:
                new_attn_slice = torch.index_select(self.last_attn_slice, -1,
                                                    self.last_attn_slice_indices)
                attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                              + new_attn_slice * self.last_attn_slice_mask)
            else:
                attn_slice = self.last_attn_slice
            self.use_last_attn_slice = False

        if self.save_last_attn_slice:
            self.last_attn_slice = attn_slice
            self.save_last_attn_slice = False

        if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
            attn_slice = attn_slice * self.last_attn_slice_weights
            self.use_last_attn_weights = False

        hidden_states = torch.matmul(attn_slice, value)

        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    def new_sliced_attention(self, query, key, value, sequence_length, dim):
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // self.heads),
            device=query.device, dtype=query.dtype
        )
        slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]

        for i in range(hidden_states.shape[0] // slice_size):
            start_idx = i * slice_size
            end_idx = (i + 1) * slice_size

            # TODO: use baddbmm for better performance
            attn_slice = torch.matmul(query[start_idx:end_idx],
                                      key[start_idx:end_idx].transpose(1, 2)) * self.scale
            attn_slice = attn_slice.softmax(dim=-1)

            if self.use_last_attn_slice:
                if self.last_attn_slice_mask is not None:
                    new_attn_slice = torch.index_select(self.last_attn_slice, -1,
                                                        self.last_attn_slice_indices)
                    attn_slice = (attn_slice * (1 - self.last_attn_slice_mask)
                                  + new_attn_slice * self.last_attn_slice_mask)
                else:
                    attn_slice = self.last_attn_slice
                self.use_last_attn_slice = False

            if self.save_last_attn_slice:
                self.last_attn_slice = attn_slice
                self.save_last_attn_slice = False

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                attn_slice = attn_slice * self.last_attn_slice_weights
                self.use_last_attn_weights = False

            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
            hidden_states[start_idx:end_idx] = attn_slice

        # reshape hidden_states
        return self.reshape_batch_dim_to_heads(hidden_states)

    for _, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention':
            module.last_attn_slice = None
            module.use_last_attn_slice = False
            module.use_last_attn_weights = False
            module.save_last_attn_slice = False
            module._sliced_attention = new_sliced_attention.__get__(module, type(module))
            module._attention = new_attention.__get__(module, type(module))


def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save
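

# ---------------------------------------------------------------------------
# Typical usage: a rough sketch of how the helpers above are driven from a
# denoising loop, following bloc97's CrossAttentionControl notebook.  The
# names `unet`, `clip_tokenizer`, `scheduler`, `latents`, `embeds`,
# `embeds_edit`, `weight_tuples` and `device` are assumed to come from a
# standard diffusers Stable Diffusion setup and are not defined here.
#
#   tokens      = clip_tokenizer(prompt,      padding='max_length', ...)
#   tokens_edit = clip_tokenizer(prompt_edit, padding='max_length', ...)
#
#   init_attention_func(unet)
#   init_attention_weights(weight_tuples, clip_tokenizer, unet, device)
#   init_attention_edit(tokens, tokens_edit, clip_tokenizer, unet, device)
#
#   for t in scheduler.timesteps:
#       # Pass 1: original prompt -- record its attention maps.
#       save_last_tokens_attention(unet)
#       save_last_self_attention(unet)
#       unet(latents, t, encoder_hidden_states=embeds)
#
#       # Pass 2: edited prompt -- reuse the saved maps where tokens align,
#       # then step the scheduler with the resulting noise prediction.
#       use_last_tokens_attention(unet)
#       use_last_tokens_attention_weights(unet)
#       use_last_self_attention(unet)
#       noise_pred = unet(latents, t, encoder_hidden_states=embeds_edit).sample
#       latents = scheduler.step(noise_pred, t, latents).prev_sample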