mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
285 lines
13 KiB
Python
285 lines
13 KiB
Python
import enum
|
|
import warnings
|
|
from typing import Optional
|
|
|
|
import torch
|
|
|
|
|
|
# adapted from bloc97's CrossAttentionControl colab
|
|
# https://github.com/bloc97/CrossAttentionControl
|
|
|
|
|
|
class Arguments:
    """
    Inputs describing a single cross-attention control edit.

    Attributes set by the constructor:
      - ``edited_conditioning``: stored as passed in.
      - ``edit_opcodes``: stored as passed in.
      - ``edit_options``: the first non-None entry of the ``edit_options`` argument,
        or None when ``edited_conditioning`` is None.
    """

    def __init__(self,
                 edited_conditioning: Optional[torch.Tensor],
                 edit_opcodes: list[tuple],
                 edit_options: list[Optional[dict]]):
        """
        :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768]
        :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required)
        :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes.
        """
        # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector
        self.edited_conditioning = edited_conditioning
        self.edit_opcodes = edit_opcodes
        # Always define the attribute so that readers never hit an AttributeError
        # when no edited conditioning was supplied.
        self.edit_options = None

        if edited_conditioning is not None:
            # NOTE: asserts are kept (rather than raising ValueError) so that the
            # exception type seen by existing callers does not change.
            assert len(edit_opcodes) == len(edit_options), \
                "there must be 1 edit_options dict for each edit_opcodes tuple"
            non_none_edit_options = [x for x in edit_options if x is not None]
            assert len(non_none_edit_options)>0, "missing edit_options"
            if len(non_none_edit_options)>1:
                print('warning: cross-attention control options are not working properly for >1 edit')
            # only the first set of options is honoured
            self.edit_options = non_none_edit_options[0]
|
|
|
|
|
|
class CrossAttentionType(enum.Enum):
    """Selects which family of attention modules an operation refers to."""

    # get_attention_modules() maps SELF to module names containing "attn1"
    SELF = 1
    # ... and any other value (i.e. TOKENS) to module names containing "attn2"
    TOKENS = 2
|
|
|
|
|
|
class Context:
    """
    Mutable state shared between the caller driving diffusion and the attention
    slice wrangler installed by inject_attention_function().

    It records which attention modules belong to which CrossAttentionType, which
    action (save / apply / none) is currently requested for each type, and the
    attention slices that have been saved so far, keyed by module identifier.
    """

    # populated by setup_cross_attention_control()
    cross_attention_mask: Optional[torch.Tensor]
    cross_attention_index_map: Optional[torch.Tensor]

    class Action(enum.Enum):
        """What the attention slice wrangler should do for a given attention type."""
        NONE = 0
        SAVE = 1
        APPLY = 2

    def __init__(self, arguments: Arguments, step_count: int):
        """
        :param arguments: Arguments for the cross-attention control process
        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
        """
        self.cross_attention_mask = None
        self.cross_attention_index_map = None
        self.self_cross_attention_action = Context.Action.NONE
        self.tokens_cross_attention_action = Context.Action.NONE
        self.arguments = arguments
        self.step_count = step_count

        # module names, filled in by register_cross_attention_modules()
        self.self_cross_attention_module_identifiers = []
        self.tokens_cross_attention_module_identifiers = []

        # identifier -> {'dim': ..., 'slice_size': ..., 'slices': {offset: tensor}}
        self.saved_cross_attention_maps = {}

        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
        """
        Record the names of the attention modules found in `model`, split by
        CrossAttentionType. Names are appended to the existing lists.

        :param model: passed through to get_attention_modules()
        """
        for name, _ in get_attention_modules(model, CrossAttentionType.SELF):
            self.self_cross_attention_module_identifiers.append(name)
        for name, _ in get_attention_modules(model, CrossAttentionType.TOKENS):
            self.tokens_cross_attention_module_identifiers.append(name)

    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
        """
        Request that attention maps be saved for the given attention type.
        Any value other than CrossAttentionType.SELF is treated as TOKENS.
        """
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.SAVE
        else:
            self.tokens_cross_attention_action = Context.Action.SAVE

    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
        """
        Request that previously saved attention maps be applied for the given attention type.
        Any value other than CrossAttentionType.SELF is treated as TOKENS.
        """
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.APPLY
        else:
            self.tokens_cross_attention_action = Context.Action.APPLY

    def is_tokens_cross_attention(self, module_identifier) -> bool:
        """:return: True if `module_identifier` was registered as a TOKENS attention module."""
        return module_identifier in self.tokens_cross_attention_module_identifiers

    def get_should_save_maps(self, module_identifier: str) -> bool:
        """
        :return: True if a SAVE action is currently requested for the attention type that
                 `module_identifier` was registered under; False for unregistered identifiers.
        """
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.SAVE
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.SAVE
        return False

    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
        """
        :return: True if an APPLY action is currently requested for the attention type that
                 `module_identifier` was registered under; False for unregistered identifiers.
        """
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.APPLY
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.APPLY
        return False

    def get_active_cross_attention_control_types_for_step(self, percent_through: Optional[float] = None) \
            -> list[CrossAttentionType]:
        """
        Should cross-attention control be applied on the given step?
        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
        """
        if percent_through is None:
            return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]

        opts = self.arguments.edit_options
        to_control = []
        # each window is half-open: start <= percent_through < end
        if opts['s_start'] <= percent_through < opts['s_end']:
            to_control.append(CrossAttentionType.SELF)
        if opts['t_start'] <= percent_through < opts['t_end']:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control

    def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
                   slice_size: Optional[int]):
        """
        Store an attention slice for the module named `identifier`.

        The first save for an identifier fixes its recorded 'dim' and 'slice_size';
        later saves only add or overwrite entries in its 'slices' dict.

        :param identifier: module identifier used as the storage key
        :param slice: the tensor to store
        :param dim: slicing dimension recorded on first save (None means unsliced)
        :param offset: key within the per-module slices dict; a falsy offset is stored under 0
        :param slice_size: slice size recorded on first save
        """
        if identifier not in self.saved_cross_attention_maps:
            self.saved_cross_attention_maps[identifier] = {
                'dim': dim,
                'slice_size': slice_size,
                'slices': {offset or 0: slice}
            }
        else:
            self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice

    def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
        """
        Fetch a previously saved attention slice.

        :param identifier: module identifier the slice was saved under
        :param requested_dim: slicing dimension the caller is using (None means unsliced)
        :param requested_offset: offset of the requested slice along `requested_dim`
        :param slice_size: size of the requested slice along `requested_dim`
        :return: the saved tensor, or a sub-range of an unsliced saved tensor
        :raises KeyError: if nothing was saved for `identifier` (or for the requested offset)
        :raises RuntimeError: if the saved slicing cannot satisfy the request
        """
        saved_attention_dict = self.saved_cross_attention_maps[identifier]
        if requested_dim is None:
            if saved_attention_dict['dim'] is not None:
                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
            return saved_attention_dict['slices'][0]

        if saved_attention_dict['dim'] == requested_dim:
            if slice_size != saved_attention_dict['slice_size']:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
            return saved_attention_dict['slices'][requested_offset]

        if saved_attention_dict['dim'] is None:
            # saved unsliced but requested sliced: cut the requested range out of the whole tensor
            whole_saved_attention = saved_attention_dict['slices'][0]
            if requested_dim == 0:
                return whole_saved_attention[requested_offset:requested_offset + slice_size]
            elif requested_dim == 1:
                return whole_saved_attention[:, requested_offset:requested_offset + slice_size]

        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")

    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
        """
        :return: the (dim, slice_size) recorded when `identifier` was first saved,
                 or (None, None) if nothing has been saved for it.
        """
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None
        return saved_attention['dim'], saved_attention['slice_size']

    def clear_requests(self, cleanup=True):
        """
        Reset both requested actions to NONE.

        :param cleanup: if truthy, also discard all saved attention maps
        """
        self.tokens_cross_attention_action = Context.Action.NONE
        self.self_cross_attention_action = Context.Action.NONE
        if cleanup:
            self.saved_cross_attention_maps = {}

    def offload_saved_attention_slices_to_cpu(self):
        """Replace every saved attention slice with the result of `.to('cpu')`."""
        for map_dict in self.saved_cross_attention_maps.values():
            slices = map_dict['slices']
            # Re-assigning existing keys does not change the dict's size, so this
            # is safe to do while iterating over it.
            for offset, saved_slice in slices.items():
                slices[offset] = saved_slice.to('cpu')
|
|
|
|
|
|
def remove_cross_attention_control(model):
    """
    Undo cross-attention control on `model` by delegating to remove_attention_function().

    :param model: forwarded unchanged to remove_attention_function()
    :return: None
    """
    remove_attention_function(model)
|
|
|
|
|
|
def setup_cross_attention_control(model, context: Context):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

    Builds a token mask and a token index map from `context.arguments.edit_opcodes`, stores
    both on `context` (moved to the device of the edited conditioning), registers the
    model's attention modules with `context`, and installs the attention hooks.

    :param model: The unet model to inject into.
    :param context: The cross-attention control Context; its `arguments.edited_conditioning`
                    must not be None because its device is read here.
    :return: None
    """

    # adapted from init_attention_edit
    target_device = context.arguments.edited_conditioning.device

    # urgh. should this be hardcoded?
    max_length = 77
    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
    mask = torch.zeros(max_length)
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.zeros(max_length, dtype=torch.long)

    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        # opcodes that start beyond the token window are ignored
        if b0 >= max_length:
            continue
        # only 'equal' opcodes (tokens that have not been edited) contribute
        if name != "equal":
            continue
        indices[b0:b1] = indices_target[a0:a1]
        mask[b0:b1] = 1

    context.register_cross_attention_modules(model)
    context.cross_attention_mask = mask.to(target_device)
    context.cross_attention_index_map = indices.to(target_device)
    inject_attention_function(model, context)
|
|
|
|
|
|
def get_attention_modules(model, which: CrossAttentionType):
    """
    Collect the attention modules of `model` that belong to the requested type.

    A module matches when its class is named "CrossAttention" and its name, as
    reported by `model.named_modules()`, contains "attn1" (for
    CrossAttentionType.SELF) or "attn2" (for anything else).

    :param model: object exposing `named_modules()` yielding (name, module) pairs
    :param which: the attention type to select
    :return: list of (name, module) tuples, in `named_modules()` order
    """
    name_fragment = "attn1" if which is CrossAttentionType.SELF else "attn2"
    matches = []
    for name, module in model.named_modules():
        # matched by class name rather than isinstance
        if type(module).__name__ != "CrossAttention":
            continue
        if name_fragment in name:
            matches.append((name, module))
    return matches
|
|
|
|
|
|
def inject_attention_function(unet, context: Context):
    """
    Install the cross-attention control hooks on every module of `unet` whose class
    is named "CrossAttention".

    Each such module gets an `identifier` attribute (its name from `named_modules()`),
    an attention slice wrangler that saves or applies attention slices according to
    `context`, and a slicing strategy getter bound to that module's own identifier.

    Modules that lack `set_attention_slice_wrangler` only produce a warning; any other
    AttributeError is re-raised.

    :param unet: object exposing `named_modules()` yielding (name, module) pairs
    :param context: the Context that decides, per module, whether to save or apply slices
    :return: None
    """
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

    def attention_slice_wrangler(module, suggested_attention_slice: torch.Tensor, dim, offset, slice_size):
        # By default the suggested slice is passed through untouched.
        attention_slice = suggested_attention_slice

        if context.get_should_save_maps(module.identifier):
            # sliced attention is moved to the CPU before saving; unsliced is kept where it is
            slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
            context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
        elif context.get_should_apply_saved_maps(module.identifier):
            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)

            # slice may have been offloaded to CPU
            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)

            if context.is_tokens_cross_attention(module.identifier):
                # reorder the saved slice along its last dim, then blend it with the
                # suggested slice: mask=1 takes the saved value, mask=0 the suggested one
                index_map = context.cross_attention_index_map
                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
                this_attention_slice = suggested_attention_slice

                mask = context.cross_attention_mask
                saved_mask = mask
                this_mask = 1 - mask
                attention_slice = remapped_saved_attention_slice * saved_mask + \
                                  this_attention_slice * this_mask
            else:
                # just use everything
                attention_slice = saved_attention_slice

        return attention_slice

    def make_slicing_strategy_getter(bound_identifier):
        # A factory gives each getter its own binding. A lambda written directly in
        # the loop below would close over the loop variable and every module would
        # end up querying the identifier of the *last* module.
        return lambda module: context.get_slicing_strategy(bound_identifier)

    cross_attention_modules = [(name, module) for (name, module) in unet.named_modules()
                               if type(module).__name__ == "CrossAttention"]
    for identifier, module in cross_attention_modules:
        module.identifier = identifier
        try:
            module.set_attention_slice_wrangler(attention_slice_wrangler)
            module.set_slicing_strategy_getter(make_slicing_strategy_getter(identifier))
        except AttributeError as e:
            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                warnings.warn(f"TODO: implement for {type(module)}")  # TODO
            else:
                raise
|
|
|
|
|
|
def remove_attention_function(unet):
    """
    Clear the attention slice wrangler and slicing strategy getter on every module
    of `unet` whose class is named "CrossAttention".

    Modules that lack `set_attention_slice_wrangler` only produce a warning; any other
    AttributeError is re-raised.

    :param unet: object exposing `named_modules()` yielding (name, module) pairs
    :return: None
    """
    # collect first, then mutate
    targets = []
    for _, candidate in unet.named_modules():
        if type(candidate).__name__ == "CrossAttention":
            targets.append(candidate)

    for module in targets:
        try:
            # passing None clears each callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
        except AttributeError as e:
            if not is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                raise
            warnings.warn(f"TODO: implement for {type(module)}")  # TODO
|
|
|
|
|
|
def is_attribute_error_about(error: AttributeError, attribute: str) -> bool:
    """
    Tell whether an AttributeError concerns the attribute named `attribute`.

    :param error: the caught AttributeError
    :param attribute: the attribute name to look for
    :return: True if `error.name` equals `attribute`; when `error.name` is missing
             (Python 3.9) or None (an AttributeError constructed from a message only),
             True if `attribute` occurs in the error message.
    """
    # Python 3.10+ always defines `name`, but leaves it as None unless the
    # interpreter (or the raiser) filled it in, so None must take the fallback too.
    name = getattr(error, 'name', None)
    if name is not None:
        return name == attribute
    # Python 3.9, or no structured name available: fall back to the message text
    return attribute in str(error)
|