mirror of https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00

commit 711ffd238f ("cleanup"), parent 056cb0d8a8
@@ -6,14 +6,16 @@ import torch
 # https://github.com/bloc97/CrossAttentionControl
 
 class CrossAttentionControl:
-    class AttentionType(Enum):
-        SELF = 1
-        TOKENS = 2
+
+    @classmethod
+    def clear_attention_editing(cls, model):
+        cls.remove_attention_function(model)
+
 
     @classmethod
     def setup_attention_editing(cls, model,
-                                substitute_conditioning: torch.Tensor = None,
-                                edit_opcodes: list = None):
+                                substitute_conditioning: torch.Tensor,
+                                edit_opcodes: list):
         """
         :param model: The unet model to inject into.
         :param substitute_conditioning: The "edited" conditioning vector, [Bx77x768]
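The effect of this hunk is to split the old dual-purpose entry point in two: setup_attention_editing now requires its conditioning arguments, and the "do nothing" path gets an explicit off switch, clear_attention_editing, which simply delegates to remove_attention_function. A minimal caller sketch under that assumption (the variable names here are illustrative, not from this commit):

    # `model` is the wrapped unet; `edited_cond` is the [Bx77x768] edited conditioning or None
    if edited_cond is not None:
        CrossAttentionControl.setup_attention_editing(model, edited_cond, opcodes)
    else:
        CrossAttentionControl.clear_attention_editing(model)  # uninstall any leftover wrangler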
@@ -23,31 +25,34 @@ class CrossAttentionControl:
         """
 
         # adapted from init_attention_edit
-        if substitute_conditioning is not None:
-
-            device = substitute_conditioning.device
+        device = substitute_conditioning.device
 
-            max_length = model.inner_model.cond_stage_model.max_length
-            # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
-            mask = torch.zeros(max_length)
-            indices_target = torch.arange(max_length, dtype=torch.long)
-            indices = torch.zeros(max_length, dtype=torch.long)
-            for name, a0, a1, b0, b1 in edit_opcodes:
-                if b0 < max_length:
-                    if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
-                        # these tokens have not been edited
-                        indices[b0:b1] = indices_target[a0:a1]
-                        mask[b0:b1] = 1
+        max_length = model.inner_model.cond_stage_model.max_length
+        # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
+        mask = torch.zeros(max_length)
+        indices_target = torch.arange(max_length, dtype=torch.long)
+        indices = torch.zeros(max_length, dtype=torch.long)
+        for name, a0, a1, b0, b1 in edit_opcodes:
+            if b0 < max_length:
+                if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
+                    # these tokens have not been edited
+                    indices[b0:b1] = indices_target[a0:a1]
+                    mask[b0:b1] = 1
 
-            for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
-                m.last_attn_slice_mask = None
-                m.last_attn_slice_indices = None
+        for m in cls.get_attention_modules(model, cls.AttentionType.SELF):
+            m.last_attn_slice_mask = None
+            m.last_attn_slice_indices = None
 
-            for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
-                m.last_attn_slice_mask = mask.to(device)
-                m.last_attn_slice_indices = indices.to(device)
+        for m in cls.get_attention_modules(model, cls.AttentionType.TOKENS):
+            m.last_attn_slice_mask = mask.to(device)
+            m.last_attn_slice_indices = indices.to(device)
 
-            cls.inject_attention_functions(model)
+        cls.inject_attention_function(model)
+
+
+    class AttentionType(Enum):
+        SELF = 1
+        TOKENS = 2
 
 
     @classmethod
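The (name, a0, a1, b0, b1) tuples unpacked from edit_opcodes have the shape of difflib.SequenceMatcher opcodes, which would explain the "equal" check: spans the edit left untouched keep the base prompt's attention (mask=1) and are mapped back to their base-prompt token positions via indices. A self-contained sketch of that construction, assuming difflib-style opcodes over token lists (max_length=10 is an arbitrary illustration; the real value comes from the tokenizer):

    import difflib
    import torch

    base_tokens   = ["a", "cat", "on", "a", "mat"]
    edited_tokens = ["a", "dog", "on", "a", "mat"]
    opcodes = difflib.SequenceMatcher(None, base_tokens, edited_tokens).get_opcodes()
    # e.g. [('equal', 0, 1, 0, 1), ('replace', 1, 2, 1, 2), ('equal', 2, 5, 2, 5)]

    max_length = 10
    mask = torch.zeros(max_length)                               # 0 = use edited-prompt attention
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.zeros(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in opcodes:
        if b0 < max_length and name == "equal":
            indices[b0:b1] = indices_target[a0:a1]               # edited position -> base position
            mask[b0:b1] = 1                                      # 1 = reuse base-prompt attention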
@@ -79,8 +84,9 @@ class CrossAttentionControl:
             m.use_last_attn_slice = True
 
 
     @classmethod
-    def inject_attention_functions(cls, unet):
+    def inject_attention_function(cls, unet):
+
         # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
 
         def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size):
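The rename from inject_attention_functions to inject_attention_function matches what the method does: it installs a single attention_slice_wrangler callback on each CrossAttention module. As a signature reference only, a trivial wrangler with the parameter list shown above would look like this (the real wrangler in this file substitutes saved attention slices; this one is a deliberate no-op):

    def passthrough_wrangler(module, attention_scores, suggested_attention_slice,
                             dim, offset, slice_size):
        # returning the suggested slice unchanged leaves attention unedited
        return suggested_attention_slice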
@@ -166,11 +172,20 @@ class CrossAttentionControl:
             module_name = type(module).__name__
             if module_name == "CrossAttention":
                 module.last_attn_slice = None
-                module.use_last_attn_slice = False
+                module.last_attn_slice_indices = None
+                module.last_attn_slice_mask = None
                 module.use_last_attn_weights = False
+                module.use_last_attn_slice = False
                 module.save_last_attn_slice = False
                 module.set_attention_slice_wrangler(attention_slice_wrangler)
 
+    @classmethod
+    def remove_attention_function(cls, unet):
+        for name, module in unet.named_modules():
+            module_name = type(module).__name__
+            if module_name == "CrossAttention":
+                module.set_attention_slice_wrangler(None)
+
 
 # original code below
 
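The new remove_attention_function is the inverse of the injection above: passing None to set_attention_slice_wrangler restores each CrossAttention module's default attention path, and the extra per-module fields (last_attn_slice_indices, last_attn_slice_mask) are now initialized up front so the wrangler never reads attributes that were never set. A small helper sketch of the traversal both methods share (the helper name is hypothetical, not part of this commit):

    def find_cross_attention_modules(unet):
        # walk all submodules and yield the CrossAttention ones, as named_modules() does above
        for name, module in unet.named_modules():
            if type(module).__name__ == "CrossAttention":
                yield name, module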
@@ -44,7 +44,7 @@ class CFGDenoiser(nn.Module):
             CrossAttentionControl.setup_attention_editing(self.inner_model, edited_conditioning, edit_opcodes)
         else:
             # pass through the attention func but don't act on it
-            CrossAttentionControl.setup_attention_editing(self.inner_model)
+            CrossAttentionControl.clear_attention_editing(self.inner_model)
 
     def forward(self, x, sigma, uncond, cond, cond_scale):
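This caller-side change follows from the first hunk: the old else-branch relied on setup_attention_editing(self.inner_model) being a silent no-op via the now-removed `if substitute_conditioning is not None:` guard, and with the None defaults gone that call would raise a TypeError. Illustration only:

    # before: silently did nothing thanks to default-None arguments
    # CrossAttentionControl.setup_attention_editing(self.inner_model)   # now a TypeError
    CrossAttentionControl.clear_attention_editing(self.inner_model)     # explicit no-op path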