InvokeAI/ldm/models/diffusion/cross_attention.py

from enum import Enum
import torch


class CrossAttentionControllableDiffusionMixin:

    def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes):
        self.edited_conditioning = edited_conditioning

        if edited_conditioning is not None:
            # <start> a cat sitting on a car <end>
            CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
        else:
            # pass through the attention func but don't act on it
            CrossAttentionControl.clear_attention_editing(model)

    def cleanup_cross_attention_control(self, model):
        CrossAttentionControl.clear_attention_editing(model)

    def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback):

        CrossAttentionControl.clear_requests(model)

        if self.edited_conditioning is None:
            # faster batched path
            x_twice = torch.cat([x] * 2)
            sigma_twice = torch.cat([sigma] * 2)
            both_conditionings = torch.cat([unconditioning, conditioning])
            unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
        else:
            # slower non-batched path (20% slower on Mac MPS)
            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
            # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8).
            # (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditioning.)
            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
            unconditioned_next_x = model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
            CrossAttentionControl.request_save_attention_maps(model)
            _ = model_forward_callback(x, sigma, cond=conditioning)
            CrossAttentionControl.clear_requests(model)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
            CrossAttentionControl.request_apply_saved_attention_maps(model)
            conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning)
            CrossAttentionControl.clear_requests(model)

        return unconditioned_next_x, conditioned_next_x
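

# Hedged usage sketch (not part of the original module): a sampler class that mixes in
# CrossAttentionControllableDiffusionMixin might drive a single classifier-free-guidance
# step roughly like this. `sampler`, `inner_model` and `cfg_scale` are illustrative
# names/assumptions, not taken from this file.
def _example_cfg_step_with_cross_attention_control(sampler, inner_model, x, sigma,
                                                   unconditioning, conditioning, cfg_scale):
    # the callback must accept the conditioning both positionally and as `cond=`,
    # because do_cross_attention_controllable_diffusion_step uses both calling styles
    def forward(x, sigma, cond):
        return inner_model(x, sigma, cond=cond)

    uncond, cond = sampler.do_cross_attention_controllable_diffusion_step(
        x, sigma, unconditioning, conditioning, inner_model, forward)
    # standard classifier-free guidance combination
    return uncond + cfg_scale * (cond - uncond)

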
# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
    @classmethod
    def clear_attention_editing(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_attention_editing(cls, model,
                                substitute_conditioning: torch.Tensor,
                                edit_opcodes: list):
        """
        :param model: The unet model to inject into.
        :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
        :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
                             conditionings map to the "edited" conditionings.
        :return:
        """
        # adapted from init_attention_edit
        device = substitute_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1
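
        # tokens not matched as "equal" keep mask=0 (and indices=0), so the attention
        # wrangler will fall back to the edited prompt's own attention at those positions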

        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)

    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_attention_modules(cls, model, which: AttentionType):
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.use_last_attn_slice = True

    @classmethod
    def inject_attention_function(cls, unet):
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):

            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)

            attn_slice = suggested_attention_slice
            if dim is not None:
                start = offset
                end = start + slice_size
                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
            #else:
            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

            if self.use_last_attn_slice:
                this_attn_slice = attn_slice
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on dim=2, no need to slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice_mask = self.last_attn_slice_mask
                    if dim is None:
                        base_attn_slice = base_attn_slice_full
                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 0:
                        base_attn_slice = base_attn_slice_full[start:end]
                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 1:
                        base_attn_slice = base_attn_slice_full[:, start:end]
                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
                                 base_attn_slice * base_attn_slice_mask
                else:
                    if dim is None:
                        attn_slice = self.last_attn_slice
                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 0:
                        attn_slice = self.last_attn_slice[start:end]
                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 1:
                        attn_slice = self.last_attn_slice[:, start:end]
                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

            if self.save_last_attn_slice:
                if dim is None:
                    self.last_attn_slice = attn_slice
                elif dim == 0:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
                    elif self.last_attn_slice.shape[0] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
                        assert self.last_attn_slice.shape[0] == end
                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                    else:
                        # no need to grow
                        self.last_attn_slice[start:end] = attn_slice
                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                elif dim == 1:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                    elif self.last_attn_slice.shape[1] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
                        assert self.last_attn_slice.shape[1] == end
                    else:
                        # no need to grow
                        self.last_attn_slice[:, start:end] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                if dim is None:
                    weights = self.last_attn_slice_weights
                elif dim == 0:
                    weights = self.last_attn_slice_weights[start:end]
                elif dim == 1:
                    weights = self.last_attn_slice_weights[:, start:end]
                attn_slice = attn_slice * weights

            return attn_slice
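
        # Assumption, inferred from the wrangler signature above: the patched CrossAttention
        # modules invoke this hook as
        #   wrangler(module, attention_scores, suggested_attention_slice, dim, offset, slice_size)
        # once per attention slice, with dim=None when the attention has not been sliced.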

        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.set_attention_slice_wrangler(None)


# original code below

# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])
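
# Hedged usage example (illustrative; the tokenizer checkpoint name and the exact output
# are assumptions, not taken from this file):
#   from transformers import CLIPTokenizer
#   tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')
#   prompt_token("a cat sitting on a car", 2, tokenizer)  # roughly 'cat' (index 0 is <|startoftext|>)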


def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save