from enum import Enum

import torch

class CrossAttentionControllableDiffusionMixin:

    def setup_cross_attention_control_if_appropriate(self, model, edited_conditioning, edit_opcodes):
        self.edited_conditioning = edited_conditioning
        if edited_conditioning is not None:
            # <start> a cat sitting on a car <end>
            CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
        else:
            # pass through the attention func but don't act on it
            CrossAttentionControl.clear_attention_editing(model)

    def cleanup_cross_attention_control(self, model):
        CrossAttentionControl.clear_attention_editing(model)
    def do_cross_attention_controllable_diffusion_step(self, x, sigma, unconditioning, conditioning, model, model_forward_callback):

        CrossAttentionControl.clear_requests(model)

        if self.edited_conditioning is None:
            # faster batched path
            x_twice = torch.cat([x] * 2)
            sigma_twice = torch.cat([sigma] * 2)
            both_conditionings = torch.cat([unconditioning, conditioning])
            unconditioned_next_x, conditioned_next_x = model_forward_callback(x_twice, sigma_twice, both_conditionings).chunk(2)
        else:
            # slower non-batched path (20% slower on mac MPS)
            # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of
            # unconditioned_next_x causes attention maps to *also* be saved for the unconditioned_next_x.
            # This messes up their application later, due to mismatched shape of dim 0 (seems to be 16 for batched vs. 8).
            # (For the batched invocation the `wrangler` function gets an attention tensor with shape[0]=16,
            # representing batched uncond + cond, but then when it comes to applying the saved attention, the
            # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditioning.)
            # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well.
            unconditioned_next_x = model_forward_callback(x, sigma, unconditioning)

            # process x using the original prompt, saving the attention maps
            CrossAttentionControl.request_save_attention_maps(model)
            _ = model_forward_callback(x, sigma, cond=conditioning)
            CrossAttentionControl.clear_requests(model)

            # process x again, using the saved attention maps to control where self.edited_conditioning will be applied
            CrossAttentionControl.request_apply_saved_attention_maps(model)
            conditioned_next_x = model_forward_callback(x, sigma, self.edited_conditioning)
            CrossAttentionControl.clear_requests(model)

        return unconditioned_next_x, conditioned_next_x
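# A minimal usage sketch (illustrative only, not part of the original module). A sampler class could
# mix CrossAttentionControllableDiffusionMixin in and delegate each denoising step to it; `MySampler`
# and `unet` below are hypothetical placeholders, and the callback just mirrors how
# model_forward_callback is invoked above.
#
#   class MySampler(CrossAttentionControllableDiffusionMixin):
#       def denoise_step(self, x, sigma, unconditioning, conditioning, unet):
#           return self.do_cross_attention_controllable_diffusion_step(
#               x, sigma, unconditioning, conditioning, unet,
#               model_forward_callback=lambda x_in, sigma_in, cond_in: unet(x_in, sigma_in, cond_in))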

# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl


class CrossAttentionControl:

    @classmethod
    def clear_attention_editing(cls, model):
        cls.remove_attention_function(model)

    @classmethod
    def setup_attention_editing(cls, model,
                                substitute_conditioning: torch.Tensor,
                                edit_opcodes: list):
        """
        :param model: The unet model to inject into.
        :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
        :param edit_opcodes: Opcodes from difflib.SequenceMatcher describing how the base
                             conditionings map to the "edited" conditionings.
        :return:
        """
        # adapted from init_attention_edit
        device = substitute_conditioning.device

        # urgh. should this be hardcoded?
        max_length = 77
        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.zeros(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":  # or (name == "replace" and a1 - a0 == b1 - b0):
                    # these tokens have not been edited
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1
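        # Worked example (hypothetical single-token substitution at position 3, both prompts padded to 77 tokens):
        # opcodes [('equal', 0, 3, 0, 3), ('replace', 3, 4, 3, 4), ('equal', 4, 77, 4, 77)] leave mask=0 only at
        # position 3, so the edited prompt's attention is used there; everywhere else mask=1 and indices map each
        # unedited position back to its position in the base prompt.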

        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
            m.last_attn_slice_mask = None
            m.last_attn_slice_indices = None

        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
            m.last_attn_slice_mask = mask.to(device)
            m.last_attn_slice_indices = indices.to(device)

        cls.inject_attention_function(model)

    class AttentionType(Enum):
        SELF = 1
        TOKENS = 2

    @classmethod
    def get_attention_modules(cls, model, which: AttentionType):
        which_attn = "attn1" if which is cls.AttentionType.SELF else "attn2"
        return [module for name, module in model.named_modules() if
                type(module).__name__ == "CrossAttention" and which_attn in name]

    @classmethod
    def clear_requests(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.save_last_attn_slice = False
            m.use_last_attn_slice = False

    @classmethod
    def request_save_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            # clear out the saved slice in case the outermost dim changes
            m.last_attn_slice = None
            m.save_last_attn_slice = True

    @classmethod
    def request_apply_saved_attention_maps(cls, model):
        self_attention_modules = cls.get_attention_modules(model, cls.AttentionType.SELF)
        tokens_attention_modules = cls.get_attention_modules(model, cls.AttentionType.TOKENS)
        for m in self_attention_modules + tokens_attention_modules:
            m.use_last_attn_slice = True

    @classmethod
    def inject_attention_function(cls, unet):
        # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

        def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
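            # Note: `self` here is the CrossAttention module the wrangler is injected into (see the loop
            # below). `suggested_attention_slice` appears to be the attention slice the module would otherwise
            # use unmodified; when attention slicing is active, `dim`, `offset` and `slice_size` identify which
            # slice of the full attention tensor this call covers (dim is None for an unsliced call).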
            #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim)

            attn_slice = suggested_attention_slice
            if dim is not None:
                start = offset
                end = start + slice_size
                #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")
            #else:
            #    print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}")

            if self.use_last_attn_slice:
                this_attn_slice = attn_slice
                if self.last_attn_slice_mask is not None:
                    # indices and mask operate on dim=2, no need to slice
                    base_attn_slice_full = torch.index_select(self.last_attn_slice, -1, self.last_attn_slice_indices)
                    base_attn_slice_mask = self.last_attn_slice_mask
                    if dim is None:
                        base_attn_slice = base_attn_slice_full
                        #print("using whole base slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 0:
                        base_attn_slice = base_attn_slice_full[start:end]
                        #print("using base dim 0 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)
                    elif dim == 1:
                        base_attn_slice = base_attn_slice_full[:, start:end]
                        #print("using base dim 1 slice of shape", base_attn_slice.shape, "from complete shape", base_attn_slice_full.shape)

                    attn_slice = this_attn_slice * (1 - base_attn_slice_mask) + \
                                 base_attn_slice * base_attn_slice_mask
                else:
                    if dim is None:
                        attn_slice = self.last_attn_slice
                        #print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 0:
                        attn_slice = self.last_attn_slice[start:end]
                        #print("took dim 0 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)
                    elif dim == 1:
                        attn_slice = self.last_attn_slice[:, start:end]
                        #print("took dim 1 slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape)

            if self.save_last_attn_slice:
                if dim is None:
                    self.last_attn_slice = attn_slice
                elif dim == 0:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                        #print("no last_attn_slice: shape now", self.last_attn_slice.shape)
                    elif self.last_attn_slice.shape[0] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=0)
                        assert self.last_attn_slice.shape[0] == end
                        #print("last_attn_slice too small, appended dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                    else:
                        # no need to grow
                        self.last_attn_slice[start:end] = attn_slice
                        #print("last_attn_slice shape is fine, setting dim 0 shape", attn_slice.shape, ", shape now", self.last_attn_slice.shape)
                elif dim == 1:
                    # dynamically grow last_attn_slice if needed
                    if self.last_attn_slice is None:
                        self.last_attn_slice = attn_slice
                    elif self.last_attn_slice.shape[1] == start:
                        self.last_attn_slice = torch.cat([self.last_attn_slice, attn_slice], dim=1)
                        assert self.last_attn_slice.shape[1] == end
                    else:
                        # no need to grow
                        self.last_attn_slice[:, start:end] = attn_slice

            if self.use_last_attn_weights and self.last_attn_slice_weights is not None:
                if dim is None:
                    weights = self.last_attn_slice_weights
                elif dim == 0:
                    weights = self.last_attn_slice_weights[start:end]
                elif dim == 1:
                    weights = self.last_attn_slice_weights[:, start:end]
                attn_slice = attn_slice * weights

            return attn_slice

        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.last_attn_slice = None
                module.last_attn_slice_indices = None
                module.last_attn_slice_mask = None
                module.use_last_attn_weights = False
                module.use_last_attn_slice = False
                module.save_last_attn_slice = False
                module.set_attention_slice_wrangler(attention_slice_wrangler)

    @classmethod
    def remove_attention_function(cls, unet):
        for name, module in unet.named_modules():
            module_name = type(module).__name__
            if module_name == "CrossAttention":
                module.set_attention_slice_wrangler(None)

# original code below

# Functions supporting Cross-Attention Control
# Copied from https://github.com/bloc97/CrossAttentionControl

from difflib import SequenceMatcher

import torch


def prompt_token(prompt, index, clip_tokenizer):
    tokens = clip_tokenizer(prompt,
                            padding='max_length',
                            max_length=clip_tokenizer.model_max_length,
                            truncation=True,
                            return_tensors='pt',
                            return_overflowing_tokens=True
                            ).input_ids[0]
    return clip_tokenizer.decode(tokens[index:index + 1])
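# Illustrative only (hypothetical tokenizer output): prompt_token("a cat sitting on a car", 2, clip_tokenizer)
# decodes just the token at index 2 of the padded sequence (index 0 is typically the start-of-text token).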

def use_last_tokens_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_slice = use


def use_last_tokens_attention_weights(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.use_last_attn_weights = use


def use_last_self_attention(unet, use=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.use_last_attn_slice = use


def save_last_tokens_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn2' in name:
            module.save_last_attn_slice = save


def save_last_self_attention(unet, save=True):
    for name, module in unet.named_modules():
        module_name = type(module).__name__
        if module_name == 'CrossAttention' and 'attn1' in name:
            module.save_last_attn_slice = save