This commit is contained in:
Damian at mba 2022-10-18 13:52:40 +02:00
parent 056cb0d8a8
commit 711ffd238f
2 changed files with 43 additions and 28 deletions

View File

@ -6,14 +6,16 @@ import torch
# https://github.com/bloc97/CrossAttentionControl
class CrossAttentionControl:
# Selector passed to get_attention_modules() to choose which group of attention
# modules to iterate over. In setup_attention_editing the SELF group has its
# last_attn_slice_mask / last_attn_slice_indices set to None, while the TOKENS
# group receives the computed edit mask and index map.
# NOTE(review): presumably SELF = self-attention layers and TOKENS = the
# cross-attention layers over prompt tokens -- confirm against
# get_attention_modules, which is not visible here.
class AttentionType(Enum):
SELF = 1
TOKENS = 2
# Turn attention editing off for `model`.
# This is a thin wrapper: it forwards `model` to remove_attention_function,
# which passes None to set_attention_slice_wrangler on every module whose class
# is named "CrossAttention". Per-module state such as last_attn_slice_mask or
# last_attn_slice_indices is not modified by this call.
@classmethod
def clear_attention_editing(cls, model):
cls.remove_attention_function(model)
@classmethod
def setup_attention_editing(cls, model,
substitute_conditioning: torch.Tensor = None,
edit_opcodes: list = None):
substitute_conditioning: torch.Tensor,
edit_opcodes: list):
"""
:param model: The unet model to inject into.
:param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
@ -23,31 +25,34 @@ class CrossAttentionControl:
"""
# adapted from init_attention_edit
if substitute_conditioning is not None:
device = substitute_conditioning.device
device = substitute_conditioning.device
max_length = model.inner_model.cond_stage_model.max_length
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
max_length = model.inner_model.cond_stage_model.max_length
# mask=1 means use base prompt attention, mask=0 means use edited prompt attention
mask = torch.zeros(max_length)
indices_target = torch.arange(max_length, dtype=torch.long)
indices = torch.zeros(max_length, dtype=torch.long)
for name, a0, a1, b0, b1 in edit_opcodes:
if b0 < max_length:
if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
# these tokens have not been edited
indices[b0:b1] = indices_target[a0:a1]
mask[b0:b1] = 1
for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
m.last_attn_slice_mask = None
m.last_attn_slice_indices = None
for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
m.last_attn_slice_mask = mask.to(device)
m.last_attn_slice_indices = indices.to(device)
cls.inject_attention_function(model)
cls.inject_attention_functions(model)
# Selector passed to get_attention_modules() to choose which group of attention
# modules to iterate over (SELF or TOKENS).
# NOTE(review): this is the same enum that also appears near the top of the
# class in this diff view; only one of the two positions should exist in the
# resulting file -- confirm which side of the diff this copy belongs to.
class AttentionType(Enum):
SELF = 1
TOKENS = 2
@classmethod
@ -79,8 +84,9 @@ class CrossAttentionControl:
m.use_last_attn_slice = True
@classmethod
def inject_attention_functions(cls, unet):
def inject_attention_function(cls, unet):
# ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
@ -166,11 +172,20 @@ class CrossAttentionControl:
module_name = type(module).__name__
if module_name == "CrossAttention":
module.last_attn_slice = None
module.use_last_attn_slice = False
module.last_attn_slice_indices = None
module.last_attn_slice_mask = None
module.use_last_attn_weights = False
module.use_last_attn_slice = False
module.save_last_attn_slice = False
module.set_attention_slice_wrangler(attention_slice_wrangler)
# Walk every submodule yielded by unet.named_modules() and, for each one whose
# class name is exactly "CrossAttention", call set_attention_slice_wrangler(None).
# NOTE(review): passing None presumably uninstalls the wrangler installed by the
# inject function and restores default attention behaviour -- confirm against
# CrossAttention.set_attention_slice_wrangler.
@classmethod
def remove_attention_function(cls, unet):
# `name` (the module's qualified name) is unpacked but not used below.
for name, module in unet.named_modules():
# Modules are matched by class-name string, not by isinstance, so any class
# called "CrossAttention" qualifies regardless of where it is defined.
module_name = type(module).__name__
if module_name == "CrossAttention":
module.set_attention_slice_wrangler(None)
# original code below

View File

@ -44,7 +44,7 @@ class CFGDenoiser(nn.Module):
CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
else:
# pass through the attention func but don't act on it
CrossAttentionControl.setup_attention_editing(self.inner_model)
CrossAttentionControl.clear_attention_editing(self.inner_model)
def forward(self, x, sigma, uncond, cond, cond_scale):