Mirror of https://github.com/invoke-ai/InvokeAI

commit 711ffd238f (parent 056cb0d8a8)

    cleanup
@@ -6,14 +6,16 @@ import torch
 # https://github.com/bloc97/CrossAttentionControl
 
 
 class CrossAttentionControl:
 
     class AttentionType(Enum):
         SELF = 1
         TOKENS = 2
 
+    @classmethod
+    def clear_attention_editing(cls, model):
+        cls.remove_attention_function(model)
+
     @classmethod
     def setup_attention_editing(cls, model,
-                                substitute_conditioning: torch.Tensor = None,
-                                edit_opcodes: list = None):
+                                substitute_conditioning: torch.Tensor,
+                                edit_opcodes: list):
         """
         :param model: The unet model to inject into.
         :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
@@ -23,31 +25,34 @@ class CrossAttentionControl:
         """
 
         # adapted from init_attention_edit
-        if substitute_conditioning is not None:
-            device = substitute_conditioning.device
+        device = substitute_conditioning.device
 
-            max_length = model.inner_model.cond_stage_model.max_length
-            # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-            mask = torch.zeros(max_length)
-            indices_target = torch.arange(max_length, dtype=torch.long)
-            indices = torch.zeros(max_length, dtype=torch.long)
-            for name, a0, a1, b0, b1 in edit_opcodes:
-                if b0 < max_length:
-                    if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
-                        # these tokens have not been edited
-                        indices[b0:b1] = indices_target[a0:a1]
-                        mask[b0:b1] = 1
+        max_length = model.inner_model.cond_stage_model.max_length
+        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.zeros(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
+                    # these tokens have not been edited
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
 
-            for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
-                m.last_attn_slice_mask = None
-                m.last_attn_slice_indices = None
+        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
+            m.last_attn_slice_mask = None
+            m.last_attn_slice_indices = None
 
-            for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
-                m.last_attn_slice_mask = mask.to(device)
-                m.last_attn_slice_indices = indices.to(device)
+        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
+            m.last_attn_slice_mask = mask.to(device)
+            m.last_attn_slice_indices = indices.to(device)
 
-            cls.inject_attention_functions(model)
-
-    class AttentionType(Enum):
-        SELF = 1
-        TOKENS = 2
+        cls.inject_attention_function(model)
 
     @classmethod
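Note: the (name, a0, a1, b0, b1) tuples and the "equal"/"replace" tags match the output of Python's difflib.SequenceMatcher.get_opcodes(); assuming the caller builds edit_opcodes that way from the base and edited prompts' token sequences (an assumption — the producer is not shown in this diff), here is a minimal sketch of the mask/indices construction above, with invented token ids:

    # Sketch only: toy token sequences for a base and an edited prompt.
    from difflib import SequenceMatcher
    import torch

    base_tokens   = [49406, 320, 1125, 539, 320, 2368, 49407]
    edited_tokens = [49406, 320, 1125, 539, 320, 1929, 49407]

    # get_opcodes() yields (tag, a0, a1, b0, b1) tuples, e.g.
    # [('equal', 0, 5, 0, 5), ('replace', 5, 6, 5, 6), ('equal', 6, 7, 6, 7)]
    edit_opcodes = SequenceMatcher(None, base_tokens, edited_tokens).get_opcodes()

    max_length = 7  # stand-in for cond_stage_model.max_length (77 for CLIP)
    mask = torch.zeros(max_length)
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.zeros(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in edit_opcodes:
        if b0 < max_length and name == "equal":
            # unedited span: remember which base-prompt position each token had
            indices[b0:b1] = indices_target[a0:a1]
            mask[b0:b1] = 1  # 1 = reuse base-prompt attention at this position

    print(mask)     # tensor([1., 1., 1., 1., 1., 0., 1.])
    print(indices)  # tensor([0, 1, 2, 3, 4, 0, 6])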
@@ -79,8 +84,9 @@ class CrossAttentionControl:
             m.use_last_attn_slice = True
 
+
     @classmethod
-    def inject_attention_functions(cls, unet):
+    def inject_attention_function(cls, unet):
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
 
         def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
@@ -166,11 +172,20 @@ class CrossAttentionControl:
             module_name = type(module).__name__
             if module_name == "CrossAttention":
                 module.last_attn_slice = None
-                module.use_last_attn_slice = False
                 module.last_attn_slice_indices = None
                 module.last_attn_slice_mask = None
                 module.use_last_attn_weights = False
+                module.use_last_attn_slice = False
                 module.save_last_attn_slice = False
                 module.set_attention_slice_wrangler(attention_slice_wrangler)
 
+    @classmethod
+    def remove_attention_function(cls, unet):
+        for name, module in unet.named_modules():
+            module_name = type(module).__name__
+            if module_name == "CrossAttention":
+                module.set_attention_slice_wrangler(None)
+
 
 # original code below
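Note: set_attention_slice_wrangler installs a callback with the signature shown in inject_attention_function; each patched CrossAttention module hands it every attention slice it computes and uses the return value instead. The sketch below shows one plausible way such a wrangler consumes the flags this diff initializes, loosely following the bloc97 approach linked at the top of the file — the body is illustrative, not the code from this commit (dim/offset/slice_size, which locate the slice within the full attention tensor, are ignored here):

    # Illustrative wrangler; `module` is the CrossAttention module itself
    # (the `self` parameter in the diff's signature).
    def attention_slice_wrangler(module, attention_scores, suggested_attention_slice,
                                 dim, offset, slice_size):
        attn_slice = suggested_attention_slice
        if module.use_last_attn_slice and module.last_attn_slice is not None:
            if module.last_attn_slice_mask is not None:
                # TOKENS attention: remap saved base-prompt columns to their
                # edited positions, then blend; mask==1 keeps base attention.
                base_slice = module.last_attn_slice.index_select(
                    -1, module.last_attn_slice_indices)
                mask = module.last_attn_slice_mask
                attn_slice = attn_slice * (1 - mask) + base_slice * mask
            else:
                # SELF attention: reuse the saved slice wholesale.
                attn_slice = module.last_attn_slice
        if module.save_last_attn_slice:
            module.last_attn_slice = suggested_attention_slice
            module.save_last_attn_slice = False
        return attn_slice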
@@ -44,7 +44,7 @@ class CFGDenoiser(nn.Module):
             CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
         else:
-            # pass through the attention func but don't act on it
-            CrossAttentionControl.setup_attention_editing(self.inner_model)
+            CrossAttentionControl.clear_attention_editing(self.inner_model)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
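Note: the net effect on CFGDenoiser is that the no-edit path now tears the attention hooks down instead of installing pass-through ones. A sketch of the resulting lifecycle (illustrative; inner_model, edited_conditioning and edit_opcodes stand in for the sampler's actual plumbing):

    # Editing lifecycle around a sampling run, using the API after this commit.
    if edited_conditioning is not None:
        # editing requested: inject the wrangler plus the token mask/indices
        CrossAttentionControl.setup_attention_editing(
            inner_model, edited_conditioning, edit_opcodes)
    else:
        # no editing: remove any previously injected wrangler entirely
        CrossAttentionControl.clear_attention_editing(inner_model)
    # ... run the sampler; CrossAttention modules consult their flags each step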