# adapted from bloc97's CrossAttentionControl colab
# https://github.com/bloc97/CrossAttentionControl

import enum
import math
from typing import Optional, Callable

import psutil
import torch
import diffusers
from torch import nn

from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.cross_attention import AttnProcessor
from ldm.invoke.devices import torch_dtype


class CrossAttentionType(enum.Enum):
    """Which kind of attention layer an operation targets ("attn1" = SELF, "attn2" = TOKENS)."""
    SELF = 1
    TOKENS = 2


class Context:
    """
    Holds the state for one cross-attention-controlled generation: the pending save/apply
    action for each attention type, the identifiers of the registered attention modules,
    and the attention maps saved so far (keyed by module identifier).
    """

    cross_attention_mask: Optional[torch.Tensor]
    cross_attention_index_map: Optional[torch.Tensor]

    class Action(enum.Enum):
        """What the attention wrangler should do with attention maps on the next pass."""
        NONE = 0
        # NOTE: previously written as `1,` which silently made the value the tuple (1,)
        SAVE = 1
        APPLY = 2

    def __init__(self, arguments: Arguments, step_count: int):
        """
        :param arguments: Arguments for the cross-attention control process
        :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run)
        """
        self.cross_attention_mask = None
        self.cross_attention_index_map = None
        self.self_cross_attention_action = Context.Action.NONE
        self.tokens_cross_attention_action = Context.Action.NONE
        self.arguments = arguments
        self.step_count = step_count

        self.self_cross_attention_module_identifiers = []
        self.tokens_cross_attention_module_identifiers = []

        self.saved_cross_attention_maps = {}

        self.clear_requests(cleanup=True)

    def register_cross_attention_modules(self, model):
        """
        Record the names of all SELF and TOKENS attention modules found in `model`.

        :param model: The model to scan (passed through to get_cross_attention_modules).
        :raises AssertionError: if a module name is registered more than once.
        """
        for name, module in get_cross_attention_modules(model, CrossAttentionType.SELF):
            if name in self.self_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.self_cross_attention_module_identifiers.append(name)
        for name, module in get_cross_attention_modules(model, CrossAttentionType.TOKENS):
            if name in self.tokens_cross_attention_module_identifiers:
                assert False, f"name {name} cannot appear more than once"
            self.tokens_cross_attention_module_identifiers.append(name)

    def request_save_attention_maps(self, cross_attention_type: CrossAttentionType):
        """Mark the given attention type so that its maps are saved on the next pass."""
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.SAVE
        else:
            self.tokens_cross_attention_action = Context.Action.SAVE

    def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType):
        """Mark the given attention type so that previously saved maps are applied on the next pass."""
        if cross_attention_type == CrossAttentionType.SELF:
            self.self_cross_attention_action = Context.Action.APPLY
        else:
            self.tokens_cross_attention_action = Context.Action.APPLY

    def is_tokens_cross_attention(self, module_identifier) -> bool:
        """Return True if `module_identifier` was registered as a TOKENS attention module."""
        return module_identifier in self.tokens_cross_attention_module_identifiers

    def get_should_save_maps(self, module_identifier: str) -> bool:
        """Return True if the pending action for this module's attention type is SAVE."""
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.SAVE
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.SAVE
        return False

    def get_should_apply_saved_maps(self, module_identifier: str) -> bool:
        """Return True if the pending action for this module's attention type is APPLY."""
        if module_identifier in self.self_cross_attention_module_identifiers:
            return self.self_cross_attention_action == Context.Action.APPLY
        elif module_identifier in self.tokens_cross_attention_module_identifiers:
            return self.tokens_cross_attention_action == Context.Action.APPLY
        return False

    def get_active_cross_attention_control_types_for_step(self, percent_through: Optional[float] = None)\
            -> list[CrossAttentionType]:
        """
        Should cross-attention control be applied on the given step?
        :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0.
        :return: A list of attention types that cross-attention control should be performed for on the given step. May be [].
        """
        if percent_through is None:
            return [CrossAttentionType.SELF, CrossAttentionType.TOKENS]

        opts = self.arguments.edit_options
        to_control = []
        if opts['s_start'] <= percent_through < opts['s_end']:
            to_control.append(CrossAttentionType.SELF)
        if opts['t_start'] <= percent_through < opts['t_end']:
            to_control.append(CrossAttentionType.TOKENS)
        return to_control

    def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int,
                   slice_size: Optional[int]):
        """
        Store one attention slice for the module `identifier`.

        The first slice saved for an identifier fixes its 'dim' and 'slice_size'; later
        slices are only added to the 'slices' dict. A falsy `offset` (None or 0) is stored
        under key 0.
        """
        if identifier not in self.saved_cross_attention_maps:
            self.saved_cross_attention_maps[identifier] = {
                'dim': dim,
                'slice_size': slice_size,
                'slices': {offset or 0: slice}
            }
        else:
            self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice

    def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int):
        """
        Fetch a previously saved attention slice for the module `identifier`.

        :param requested_dim: None for the whole (unsliced) map, or 0/1 for a slice along that dimension.
        :param requested_offset: Start of the requested slice along `requested_dim`.
        :param slice_size: Length of the requested slice along `requested_dim`.
        :raises KeyError: if nothing was saved for `identifier` (or for the requested offset).
        :raises RuntimeError: if the saved slicing layout cannot satisfy the request.
        """
        saved_attention_dict = self.saved_cross_attention_maps[identifier]
        if requested_dim is None:
            if saved_attention_dict['dim'] is not None:
                raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}")
            return saved_attention_dict['slices'][0]

        if saved_attention_dict['dim'] == requested_dim:
            if slice_size != saved_attention_dict['slice_size']:
                raise RuntimeError(
                    f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}")
            return saved_attention_dict['slices'][requested_offset]

        if saved_attention_dict['dim'] is None:
            # the map was saved whole, so the requested slice can be cut out of it on demand
            whole_saved_attention = saved_attention_dict['slices'][0]
            if requested_dim == 0:
                return whole_saved_attention[requested_offset:requested_offset + slice_size]
            elif requested_dim == 1:
                return whole_saved_attention[:, requested_offset:requested_offset + slice_size]

        raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}")

    def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]:
        """Return the (dim, slice_size) the maps for `identifier` were saved with, or (None, None) if none are saved."""
        saved_attention = self.saved_cross_attention_maps.get(identifier, None)
        if saved_attention is None:
            return None, None
        return saved_attention['dim'], saved_attention['slice_size']

    def clear_requests(self, cleanup=True):
        """
        Reset both pending actions to NONE.

        :param cleanup: If True, also discard all saved attention maps.
        """
        self.tokens_cross_attention_action = Context.Action.NONE
        self.self_cross_attention_action = Context.Action.NONE
        if cleanup:
            self.saved_cross_attention_maps = {}

    def offload_saved_attention_slices_to_cpu(self):
        """Replace every saved attention slice with a copy on the CPU."""
        for key, map_dict in self.saved_cross_attention_maps.items():
            slices = map_dict['slices']
            # Only values of existing keys are replaced, so mutating while iterating is safe.
            # (Previously the result was written to map_dict[offset], leaving the real slices untouched.)
            for offset, slice in slices.items():
                slices[offset] = slice.to('cpu')


class InvokeAICrossAttentionMixin:
    """
    Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
    through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
    and dynamic slicing strategy selection.
    """
    def __init__(self):
        self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
        self.attention_slice_wrangler = None
        self.slicing_strategy_getter = None
        self.attention_slice_calculated_callback = None

    def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
        '''
        Set custom attention calculator to be called when attention is calculated
        :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
        which returns either the suggested_attention_slice or an adjusted equivalent.
            `module` is the current CrossAttention module for which the callback is being invoked.
            `suggested_attention_slice` is the default-calculated attention slice
            `dim` is None if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
                If `dim` is not None, `offset` and `slice_size` specify the slice start and length
                (they are None when `dim` is None).

        Pass None to use the default attention calculation.
        :return:
        '''
        self.attention_slice_wrangler = wrangler

    def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
        """
        Set the callback consulted by einsum_op_cuda to pick a slicing strategy.
        :param getter: Called with this module; returns (dim, slice_size). Pass None to disable.
        """
        self.slicing_strategy_getter = getter

    def set_attention_slice_calculated_callback(
            self,
            callback: Optional[Callable[[torch.Tensor, Optional[int], Optional[int], Optional[int]], None]]):
        """
        Set a callback invoked after each attention slice has been calculated.
        :param callback: Called with (attention_slice, dim, offset, slice_size). Pass None to disable.
        """
        self.attention_slice_calculated_callback = callback

    def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
        """
        Compute attention for one slice and return the resulting hidden states.

        `dim`, `offset` and `slice_size` describe how the slice was cut (all None when unsliced);
        they are only forwarded to the wrangler and the slice-calculated callback.
        """
        # calculate attention scores
        #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,  # beta=0 means the (uninitialised) empty tensor contributes nothing
            alpha=self.scale,
        )

        # calculate attention slice by taking the best scores for each latent pixel
        default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
        attention_slice_wrangler = self.attention_slice_wrangler
        if attention_slice_wrangler is not None:
            attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
        else:
            attention_slice = default_attention_slice

        if self.attention_slice_calculated_callback is not None:
            self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)

        hidden_states = torch.bmm(attention_slice, value)
        return hidden_states

    def einsum_op_slice_dim0(self, q, k, v, slice_size):
        """Compute attention in chunks of `slice_size` along dimension 0 of q/k/v."""
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], slice_size):
            end = i + slice_size
            r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
        return r

    def einsum_op_slice_dim1(self, q, k, v, slice_size):
        """Compute attention in chunks of `slice_size` along dimension 1 of q (k and v are used whole)."""
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
        return r

    def einsum_op_mps_v1(self, q, k, v):
        """Unsliced attention for q.shape[1] <= 4096, otherwise slice along dimension 1."""
        if q.shape[1] <= 4096:  # (512x512) max q.shape[1]: 4096
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
            return self.einsum_op_slice_dim1(q, k, v, slice_size)

    def einsum_op_mps_v2(self, q, k, v):
        """Unsliced attention when there is > 8GB system memory and q.shape[1] <= 4096, otherwise slice dim 0 one row at a time."""
        if self.mem_total_gb > 8 and q.shape[1] <= 4096:
            return self.einsum_lowest_level(q, k, v, None, None, None)
        else:
            return self.einsum_op_slice_dim0(q, k, v, 1)

    def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
        """
        Pick a slicing strategy so that the attention score tensor stays at or below `max_tensor_mb` megabytes.
        """
        size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
        if size_mb <= max_tensor_mb:
            return self.einsum_lowest_level(q, k, v, None, None, None)
        # smallest power of two that is large enough to divide size_mb down to max_tensor_mb
        div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
        if div <= q.shape[0]:
            return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
        return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))

    def einsum_op_cuda(self, q, k, v):
        """Attention on CUDA: re-use a saved slicing strategy if there is one, otherwise size slices by free memory."""
        # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
        slicing_strategy_getter = self.slicing_strategy_getter
        if slicing_strategy_getter is not None:
            (dim, slice_size) = slicing_strategy_getter(self)
            if dim is not None:
                # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
                if dim == 0:
                    return self.einsum_op_slice_dim0(q, k, v, slice_size)
                elif dim == 1:
                    return self.einsum_op_slice_dim1(q, k, v, slice_size)

        # fallback for when there is no saved strategy, or saved strategy does not slice
        mem_free_total = get_mem_free_total(q.device)
        # Divide factor of safety as there's copying and fragmentation
        return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))

    def get_invokeai_attention_mem_efficient(self, q, k, v):
        """Dispatch to the attention implementation appropriate for the device `q` lives on."""
        if q.device.type == 'cuda':
            #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
            return self.einsum_op_cuda(q, k, v)

        if q.device.type == 'mps' or q.device.type == 'cpu':
            if self.mem_total_gb >= 32:
                return self.einsum_op_mps_v1(q, k, v)
            return self.einsum_op_mps_v2(q, k, v)

        # Smaller slices are faster due to L2/L3/SLC caches.
        # Tested on i7 with 8MB L3 cache.
        return self.einsum_op_tensor_mem(q, k, v, 32)


def restore_default_cross_attention(model, is_running_diffusers: bool, restore_attention_processor: Optional[AttnProcessor]=None):
    """
    Undo override_cross_attention.

    :param model: The unet model that was overridden.
    :param is_running_diffusers: If True, reset the attention processor; otherwise remove the injected callbacks.
    :param restore_attention_processor: Processor to restore (diffusers only); defaults to a fresh CrossAttnProcessor.
    """
    if is_running_diffusers:
        unet = model
        unet.set_attn_processor(restore_attention_processor or CrossAttnProcessor())
    else:
        remove_attention_function(model)


def override_cross_attention(model, context: Context, is_running_diffusers = False):
    """
    Inject attention parameters and functions into the passed in model to enable cross attention editing.

    :param model: The unet model to inject into.
    :param context: Receives the computed cross_attention_mask and cross_attention_index_map.
    :param is_running_diffusers: If True, swap the unet's attention processor instead of injecting callbacks.
    :return: The previous attention processors when running diffusers, otherwise None
    """

    # adapted from init_attention_edit
    device = context.arguments.edited_conditioning.device

    # urgh. should this be hardcoded?
    max_length = 77
    # mask=1 means use base prompt attention, mask=0 means use edited prompt attention
    mask = torch.zeros(max_length, dtype=torch_dtype(device))
    indices_target = torch.arange(max_length, dtype=torch.long)
    indices = torch.arange(max_length, dtype=torch.long)
    for name, a0, a1, b0, b1 in context.arguments.edit_opcodes:
        if b0 < max_length:
            if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0):
                # these tokens have not been edited
                indices[b0:b1] = indices_target[a0:a1]
                mask[b0:b1] = 1

    context.cross_attention_mask = mask.to(device)
    context.cross_attention_index_map = indices.to(device)
    if is_running_diffusers:
        unet = model
        old_attn_processors = unet.attn_processors
        if torch.backends.mps.is_available():
            # see note in StableDiffusionGeneratorPipeline.__init__ about borked slicing on MPS
            unet.set_attn_processor(SwapCrossAttnProcessor())
        else:
            # try to re-use an existing slice size
            # (exact type match on purpose: the swap processors subclass SlicedAttnProcessor too)
            default_slice_size = 4
            slice_size = next((p.slice_size for p in old_attn_processors.values() if type(p) is SlicedAttnProcessor), default_slice_size)
            unet.set_attn_processor(SlicedSwapCrossAttnProcesser(slice_size=slice_size))
        return old_attn_processors
    else:
        context.register_cross_attention_modules(model)
        inject_attention_function(model, context)
        return None


def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]:
    """
    Find the attention modules of the requested type in `model`.

    :param model: The model whose named_modules() are scanned.
    :param which: SELF selects modules with "attn1" in their name, TOKENS those with "attn2".
    :return: A list of (name, module) tuples. A warning is printed if the count is not the expected 16.
    """
    from ldm.modules.attention import CrossAttention # avoid circular import
    cross_attention_class: type = InvokeAIDiffusersCrossAttention if isinstance(model,UNet2DConditionModel) else CrossAttention
    which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2"
    attention_module_tuples = [(name,module) for name, module in model.named_modules() if
                isinstance(module, cross_attention_class) and which_attn in name]
    cross_attention_modules_in_model_count = len(attention_module_tuples)
    expected_count = 16
    if cross_attention_modules_in_model_count != expected_count:
        # non-fatal error but .swap() won't work.
        print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " +
              f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " +
              f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " +
              f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " +
              f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " +
              f"work properly until it is fixed.")
    return attention_module_tuples


def inject_attention_function(unet, context: Context):
    """
    Tag every attention module in `unet` with its identifier and install callbacks that
    save or apply attention maps according to the pending actions in `context`.
    """
    # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276

    def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):

        #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()

        attention_slice = suggested_attention_slice

        if context.get_should_save_maps(module.identifier):
            #print(module.identifier, "saving suggested_attention_slice of shape",
            #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
            # sliced maps are moved to the CPU before saving; an unsliced map is saved where it is
            slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
            context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
        elif context.get_should_apply_saved_maps(module.identifier):
            #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
            saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)

            # slice may have been offloaded to CPU
            saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)

            if context.is_tokens_cross_attention(module.identifier):
                index_map = context.cross_attention_index_map
                remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
                this_attention_slice = suggested_attention_slice

                mask = context.cross_attention_mask.to(torch_dtype(suggested_attention_slice.device))
                saved_mask = mask
                this_mask = 1 - mask
                attention_slice = remapped_saved_attention_slice * saved_mask + \
                                  this_attention_slice * this_mask
            else:
                # just use everything
                attention_slice = saved_attention_slice

        return attention_slice

    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        module.identifier = identifier
        try:
            module.set_attention_slice_wrangler(attention_slice_wrangler)
            # Look the identifier up on the module that is passed to the getter. Closing over the
            # loop variable would late-bind it, making every module use the last identifier.
            module.set_slicing_strategy_getter(
                lambda module: context.get_slicing_strategy(module.identifier)
            )
        except AttributeError as e:
            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
            else:
                raise


def remove_attention_function(unet):
    """Clear the wrangler and slicing-strategy callbacks from every attention module in `unet`."""
    cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
    for identifier, module in cross_attention_modules:
        try:
            # clear wrangler callback
            module.set_attention_slice_wrangler(None)
            module.set_slicing_strategy_getter(None)
        except AttributeError as e:
            if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
                print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
            else:
                raise


def is_attribute_error_about(error: AttributeError, attribute: str):
    """Return whether `error` concerns the attribute named `attribute`."""
    if hasattr(error, 'name'): # Python 3.10
        return error.name == attribute
    else: # Python 3.9
        return attribute in str(error)


def get_mem_free_total(device):
    """
    Return the number of bytes available for allocation on the CUDA `device`: memory free on the
    device plus memory torch has reserved but is not actively using. Returns None if CUDA is unavailable.
    """
    #only on cuda
    if not torch.cuda.is_available():
        return None
    stats = torch.cuda.memory_stats(device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    return mem_free_total


class InvokeAIDiffusersCrossAttention(diffusers.models.attention.CrossAttention, InvokeAICrossAttentionMixin):
    """diffusers CrossAttention whose `_attention` is routed through the InvokeAI memory-efficient implementation."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        InvokeAICrossAttentionMixin.__init__(self)

    def _attention(self, query, key, value, attention_mask=None):
        """Compute attention via the mixin; `attention_mask` is ignored (a notice is printed if one is passed)."""
        #default_result = super()._attention(query, key, value)
        if attention_mask is not None:
            print(f"{type(self).__name__} ignoring passed-in attention_mask")
        attention_result = self.get_invokeai_attention_mem_efficient(query, key, value)

        hidden_states = self.reshape_batch_dim_to_heads(attention_result)
        return hidden_states




## 🧨diffusers implementation follows


"""
# base implementation

class CrossAttnProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

"""
from dataclasses import field, dataclass

import torch

from diffusers.models.cross_attention import CrossAttention, CrossAttnProcessor, SlicedAttnProcessor


@dataclass
class SwapCrossAttnContext:
    """
    Per-call state for the swap attention processors. Instances are built with the
    dataclass-generated constructor.
    """
    modified_text_embeddings: torch.Tensor
    index_map: torch.Tensor # maps from original prompt token indices to the equivalent tokens in the modified prompt
    mask: torch.Tensor # in the target space of the index_map
    cross_attention_types_to_do: list[CrossAttentionType] = field(default_factory=list)

    def wants_cross_attention_control(self, attn_type: CrossAttentionType) -> bool:
        """Return True if cross-attention control is enabled for `attn_type`."""
        return attn_type in self.cross_attention_types_to_do

    @classmethod
    def make_mask_and_index_map(cls, edit_opcodes: list[tuple[str, int, int, int, int]], max_length: int) \
            -> tuple[torch.Tensor, torch.Tensor]:
        """
        Build the blend mask and token index map from a list of edit opcodes.

        :param edit_opcodes: Tuples of (name, a0, a1, b0, b1); only "equal" opcodes with b0 < max_length are used.
        :param max_length: Length of the returned tensors.
        :return: (mask, indices), each of length `max_length`.
        """

        # mask=1 means use original prompt attention, mask=0 means use modified prompt attention
        mask = torch.zeros(max_length)
        indices_target = torch.arange(max_length, dtype=torch.long)
        indices = torch.arange(max_length, dtype=torch.long)
        for name, a0, a1, b0, b1 in edit_opcodes:
            if b0 < max_length:
                if name == "equal":
                    # these tokens remain the same as in the original prompt
                    indices[b0:b1] = indices_target[a0:a1]
                    mask[b0:b1] = 1

        return mask, indices


class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
    """
    Sliced attention processor that, for TOKENS attention under cross-attention control, blends the
    attention probabilities of the original prompt with those of the modified prompt.
    """

    # TODO: dynamically pick slice size based on memory conditions

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None,
                 # kwargs
                 swap_cross_attn_context: SwapCrossAttnContext=None):
        """
        :param attn: The attention module being processed.
        :param hidden_states: Input hidden states, unpacked as (batch, sequence, channels).
        :param encoder_hidden_states: Original prompt embeddings; None means this is self-attention.
        :param attention_mask: Optional attention mask, prepared via attn.prepare_attention_mask.
        :param swap_cross_attn_context: Cross-attention control state; None disables control.
        :return: The processed hidden states.
        """

        attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

        # if cross-attention control is not in play, just call through to the base implementation.
        if attention_type is CrossAttentionType.SELF or \
                swap_cross_attn_context is None or \
                not swap_cross_attn_context.wants_cross_attention_control(attention_type):
            #print(f"SwapCrossAttnContext for {attention_type} not active - passing request to superclass")
            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask)
        #else:
        #    print(f"SwapCrossAttnContext for {attention_type} active")

        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        original_text_embeddings = encoder_hidden_states
        modified_text_embeddings = swap_cross_attn_context.modified_text_embeddings
        original_text_key = attn.to_k(original_text_embeddings)
        modified_text_key = attn.to_k(modified_text_embeddings)
        original_value = attn.to_v(original_text_embeddings)
        modified_value = attn.to_v(modified_text_embeddings)

        original_text_key = attn.head_to_batch_dim(original_text_key)
        modified_text_key = attn.head_to_batch_dim(modified_text_key)
        original_value = attn.head_to_batch_dim(original_value)
        modified_value = attn.head_to_batch_dim(modified_value)

        # compute slices and prepare output tensor
        batch_size_attention = query.shape[0]
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # only some tokens taken from the original attention probabilities. this is controlled by the mask.
        mask = swap_cross_attn_context.mask
        inverse_mask = 1 - mask

        # do slices
        # Stepping by slice_size also covers a trailing partial slice; counting whole slices
        # (batch // slice_size) would leave the remainder rows as zeros.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size

            query_slice = query[start_idx:end_idx]
            original_key_slice = original_text_key[start_idx:end_idx]
            modified_key_slice = modified_text_key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            original_attn_slice = attn.get_attention_scores(query_slice, original_key_slice, attn_mask_slice)
            modified_attn_slice = attn.get_attention_scores(query_slice, modified_key_slice, attn_mask_slice)

            # because the prompt modifications may result in token sequences shifted forwards or backwards,
            # the original attention probabilities must be remapped to account for token index changes in the
            # modified prompt
            remapped_original_attn_slice = torch.index_select(original_attn_slice, -1,
                                                              swap_cross_attn_context.index_map)

            attn_slice = \
                remapped_original_attn_slice * mask + \
                modified_attn_slice * inverse_mask

            del remapped_original_attn_slice, modified_attn_slice

            attn_slice = torch.bmm(attn_slice, modified_value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        # done
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


class SwapCrossAttnProcessor(SlicedSwapCrossAttnProcesser):
    """SlicedSwapCrossAttnProcesser configured with a slice size so large that it never actually slices."""

    def __init__(self):
        super(SwapCrossAttnProcessor, self).__init__(slice_size=int(1e9)) # massive slice size = don't slice