Damian at mba 2022-10-18 13:52:40 +02:00
parent 056cb0d8a8
commit 711ffd238f
2 changed files with 43 additions and 28 deletions

Changed file 1 of 2:

@@ -6,14 +6,16 @@ import torch
 # https://github.com/bloc97/CrossAttentionControl
 class CrossAttentionControl:
 
-    class AttentionType(Enum):
-        SELF = 1
-        TOKENS = 2
+    @classmethod
+    def clear_attention_editing(cls, model):
+        cls.remove_attention_function(model)
 
     @classmethod
     def setup_attention_editing(cls, model,
-                                substitute_conditioning: torch.Tensor = None,
-                                edit_opcodes: list = None):
+                                substitute_conditioning: torch.Tensor,
+                                edit_opcodes: list):
         """
         :param model: The unet model to inject into.
         :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
@@ -23,8 +25,6 @@ class CrossAttentionControl:
         """
         # adapted from init_attention_edit
-        if substitute_conditioning is not None:
         device = substitute_conditioning.device
 
         max_length = model.inner_model.cond_stage_model.max_length
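Note: the diff never shows how edit_opcodes is constructed. Purely as an illustration (the difflib.SequenceMatcher opcode format, the token ids and the variable names below are assumptions, not code from this commit), a list describing how the original prompt tokens map onto the edited ones could be built like this:

from difflib import SequenceMatcher

# Hypothetical token ids for an original and an edited prompt ("... cat ..." vs "... dog ...").
original_token_ids = [320, 4380, 2368, 525]
edited_token_ids = [320, 1929, 2368, 525]

# Assumption: edit_opcodes is a plain list of SequenceMatcher-style opcodes.
edit_opcodes = SequenceMatcher(None, original_token_ids, edited_token_ids).get_opcodes()
# -> [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 4, 2, 4)]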
@@ -47,7 +47,12 @@ class CrossAttentionControl:
             m.last_attn_slice_mask = mask.to(device)
             m.last_attn_slice_indices = indices.to(device)
 
-        cls.inject_attention_functions(model)
+        cls.inject_attention_function(model)
+
+    class AttentionType(Enum):
+        SELF = 1
+        TOKENS = 2
 
     @classmethod
@@ -79,8 +84,9 @@ class CrossAttentionControl:
             m.use_last_attn_slice = True
 
     @classmethod
-    def inject_attention_functions(cls, unet):
+    def inject_attention_function(cls, unet):
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
         def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
@@ -166,11 +172,20 @@ class CrossAttentionControl:
             module_name = type(module).__name__
             if module_name == "CrossAttention":
                 module.last_attn_slice = None
-                module.use_last_attn_slice = False
+                module.last_attn_slice_indices = None
+                module.last_attn_slice_mask = None
                 module.use_last_attn_weights = False
+                module.use_last_attn_slice = False
                 module.save_last_attn_slice = False
                 module.set_attention_slice_wrangler(attention_slice_wrangler)
 
+    @classmethod
+    def remove_attention_function(cls, unet):
+        for name, module in unet.named_modules():
+            module_name = type(module).__name__
+            if module_name == "CrossAttention":
+                module.set_attention_slice_wrangler(None)
 
 # original code below
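For readers unfamiliar with the hook used above: set_attention_slice_wrangler appears to register a per-module callback with the signature shown in inject_attention_function, and passing None (as remove_attention_function does) restores stock attention. A minimal sketch, assuming that reading of the API; passthrough_wrangler and unet are placeholder names, not code from this commit:

# A wrangler that changes nothing: it presumably just hands back the slice the module
# suggested, so attention behaves exactly as it would without cross attention control.
def passthrough_wrangler(module, attention_scores, suggested_attention_slice, dim, offset, slice_size):
    return suggested_attention_slice

# Installing the hook on every CrossAttention module (mirrors inject_attention_function):
for name, module in unet.named_modules():
    if type(module).__name__ == "CrossAttention":
        module.set_attention_slice_wrangler(passthrough_wrangler)

# Removing it again (mirrors remove_attention_function):
for name, module in unet.named_modules():
    if type(module).__name__ == "CrossAttention":
        module.set_attention_slice_wrangler(None)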

Changed file 2 of 2:

@@ -44,7 +44,7 @@ class CFGDenoiser(nn.Module):
             CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
         else:
             # pass through the attention func but don't act on it
-            CrossAttentionControl.setup_attention_editing(self.inner_model)
+            CrossAttentionControl.clear_attention_editing(self.inner_model)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
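Taken together with the first file, the calling convention after this change is: pass both a real conditioning tensor and real opcodes to setup_attention_editing, or call clear_attention_editing for the no-edit path. A short sketch of that pattern (model, edited_conditioning and edit_opcodes are placeholders for the caller's own objects):

if edited_conditioning is not None:
    # Both arguments are now required; the `= None` defaults were removed in this commit.
    CrossAttentionControl.setup_attention_editing(model, edited_conditioning, edit_opcodes)
else:
    # The "no edits" path no longer goes through setup_attention_editing at all.
    CrossAttentionControl.clear_attention_editing(model)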