From ff42027a002a2fad81b04d8538053ba1c7fcd3c0 Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Sun, 4 Dec 2022 16:07:54 +0100 Subject: [PATCH] add cross-attention control support to diffusers (fails on MPS) For unknown reasons MPS produces garbage output with .swap(). Use --always_use_cpu arg to invoke.py for now to test this code on MPS. --- ldm/invoke/generator/diffusers_pipeline.py | 8 + .../diffusion/cross_attention_control.py | 184 ++++++++++++++++-- ldm/modules/attention.py | 2 + 3 files changed, 182 insertions(+), 12 deletions(-) diff --git a/ldm/invoke/generator/diffusers_pipeline.py b/ldm/invoke/generator/diffusers_pipeline.py index 989d43546e..174a57fd85 100644 --- a/ldm/invoke/generator/diffusers_pipeline.py +++ b/ldm/invoke/generator/diffusers_pipeline.py @@ -9,6 +9,14 @@ import PIL.Image import einops import torch import torchvision.transforms as T + +from ldm.models.diffusion.cross_attention_control import InvokeAICrossAttention + +from diffusers.models import attention +# monkeypatch diffusers CrossAttention 🙈 +# this is to make prompt2prompt and (future) attention maps work +attention.CrossAttention = InvokeAICrossAttention + from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index ec7c3c215c..ec4c344716 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -1,8 +1,11 @@ import enum -import warnings -from typing import Optional +import math +from typing import Optional, Callable +import psutil import torch +import diffusers +from torch import nn # adapted from bloc97's CrossAttentionControl colab @@ -66,8 +69,12 @@ class Context: def register_cross_attention_modules(self, model): for name,module in get_attention_modules(model, CrossAttentionType.SELF): + if name in self.self_cross_attention_module_identifiers: + assert False, f"name {name} cannot appear more than once" self.self_cross_attention_module_identifiers.append(name) for name,module in get_attention_modules(model, CrossAttentionType.TOKENS): + if name in self.tokens_cross_attention_module_identifiers: + assert False, f"name {name} cannot appear more than once" self.tokens_cross_attention_module_identifiers.append(name) def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): @@ -189,7 +196,7 @@ def setup_cross_attention_control(model, context: Context): # mask=1 means use base prompt attention, mask=0 means use edited prompt attention mask = torch.zeros(max_length) indices_target = torch.arange(max_length, dtype=torch.long) - indices = torch.zeros(max_length, dtype=torch.long) + indices = torch.arange(max_length, dtype=torch.long) for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: if b0 < max_length: if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): @@ -204,9 +211,22 @@ def setup_cross_attention_control(model, context: Context): def get_attention_modules(model, which: CrossAttentionType): + # cross_attention_class: type = ldm.modules.attention.CrossAttention + cross_attention_class: type = InvokeAICrossAttention which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" - return [(name,module) for name, module in model.named_modules() if - type(module).__name__ == "CrossAttention" and which_attn in name] + attention_module_tuples = [(name,module) for name, module in model.named_modules() if + isinstance(module, cross_attention_class) and which_attn in name] + cross_attention_modules_in_model_count = len(attention_module_tuples) + expected_count = 16 + if cross_attention_modules_in_model_count != expected_count: + # non-fatal error but .swap() won't work. + print(f"Error! CrossAttentionControl found an unexpected number of InvokeAICrossAttention modules in the model " + + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " + + f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " + + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " + + f"what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not " + + f"work properly until it is fixed.") + return attention_module_tuples def inject_attention_function(unet, context: Context): @@ -246,8 +266,7 @@ def inject_attention_function(unet, context: Context): return attention_slice - cross_attention_modules = [(name, module) for (name, module) in unet.named_modules() - if type(module).__name__ == "CrossAttention"] + cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF) for identifier, module in cross_attention_modules: module.identifier = identifier try: @@ -257,22 +276,21 @@ def inject_attention_function(unet, context: Context): ) except AttributeError as e: if is_attribute_error_about(e, 'set_attention_slice_wrangler'): - warnings.warn(f"TODO: implement for {type(module)}") # TODO + print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO else: raise def remove_attention_function(unet): - cross_attention_modules = [module for (_, module) in unet.named_modules() - if type(module).__name__ == "CrossAttention"] - for module in cross_attention_modules: + cross_attention_modules = get_attention_modules(unet, CrossAttentionType.TOKENS) + get_attention_modules(unet, CrossAttentionType.SELF) + for identifier, module in cross_attention_modules: try: # clear wrangler callback module.set_attention_slice_wrangler(None) module.set_slicing_strategy_getter(None) except AttributeError as e: if is_attribute_error_about(e, 'set_attention_slice_wrangler'): - warnings.warn(f"TODO: implement for {type(module)}") # TODO + print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") else: raise @@ -282,3 +300,145 @@ def is_attribute_error_about(error: AttributeError, attribute: str): return error.name == attribute else: # Python 3.9 return attribute in str(error) + + + +def get_mem_free_total(device): + #only on cuda + if not torch.cuda.is_available(): + return None + stats = torch.cuda.memory_stats(device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total + +class InvokeAICrossAttention(diffusers.models.attention.CrossAttention): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) + + self.attention_slice_wrangler = None + self.slicing_strategy_getter = None + + def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]): + ''' + Set custom attention calculator to be called when attention is calculated + :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), + which returns either the suggested_attention_slice or an adjusted equivalent. + `module` is the current CrossAttention module for which the callback is being invoked. + `suggested_attention_slice` is the default-calculated attention slice + `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. + + Pass None to use the default attention calculation. + :return: + ''' + self.attention_slice_wrangler = wrangler + + def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]): + self.slicing_strategy_getter = getter + + def _attention(self, query, key, value): + #default_result = super()._attention(query, key, value) + damian_result = self.get_attention_mem_efficient(query, key, value) + + hidden_states = self.reshape_batch_dim_to_heads(damian_result) + return hidden_states + + def einsum_lowest_level(self, query, key, value, dim, offset, slice_size): + # calculate attention scores + #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + if dim is not None: + print(f"sliced dim {dim}, offset {offset}, slice_size {slice_size}") + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + # calculate attention slice by taking the best scores for each latent pixel + default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) + attention_slice_wrangler = self.attention_slice_wrangler + if attention_slice_wrangler is not None: + attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) + else: + attention_slice = default_attention_slice + + #return torch.einsum('b i j, b j d -> b i d', attention_slice, v) + hidden_states = torch.bmm(attention_slice, value) + return hidden_states + + + def einsum_op_slice_dim0(self, q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], slice_size): + end = i + slice_size + r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size) + return r + + def einsum_op_slice_dim1(self, q, k, v, slice_size): + r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size) + return r + + def einsum_op_mps_v1(self, q, k, v): + if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096 + return self.einsum_lowest_level(q, k, v, None, None, None) + else: + slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1])) + return self.einsum_op_slice_dim1(q, k, v, slice_size) + + def einsum_op_mps_v2(self, q, k, v): + if self.mem_total_gb > 8 and q.shape[1] <= 4096: + return self.einsum_lowest_level(q, k, v, None, None, None) + else: + return self.einsum_op_slice_dim0(q, k, v, 1) + + def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb): + size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) + if size_mb <= max_tensor_mb: + return self.einsum_lowest_level(q, k, v, None, None, None) + div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length() + if div <= q.shape[0]: + return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div) + return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) + + def einsum_op_cuda(self, q, k, v): + # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) + slicing_strategy_getter = self.slicing_strategy_getter + if slicing_strategy_getter is not None: + (dim, slice_size) = slicing_strategy_getter(self) + if dim is not None: + # print("using saved slicing strategy with dim", dim, "slice size", slice_size) + if dim == 0: + return self.einsum_op_slice_dim0(q, k, v, slice_size) + elif dim == 1: + return self.einsum_op_slice_dim1(q, k, v, slice_size) + + # fallback for when there is no saved strategy, or saved strategy does not slice + mem_free_total = get_mem_free_total(q.device) + # Divide factor of safety as there's copying and fragmentation + return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + + + def get_attention_mem_efficient(self, q, k, v): + if q.device.type == 'cuda': + #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) + return self.einsum_op_cuda(q, k, v) + + if q.device.type == 'mps' or q.device.type == 'cpu': + if self.mem_total_gb >= 32: + return self.einsum_op_mps_v1(q, k, v) + return self.einsum_op_mps_v2(q, k, v) + + # Smaller slices are faster due to L2/L3/SLC caches. + # Tested on i7 with 8MB L3 cache. + return self.einsum_op_tensor_mem(q, k, v, 32) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 94bb8a2916..94922270a4 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -165,7 +165,9 @@ def get_mem_free_total(device): class CrossAttention(nn.Module): + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): + print(f"Warning! ldm.modules.attention.CrossAttention is no longer being maintained. Please use InvokeAICrossAttention instead.") super().__init__() inner_dim = dim_head * heads context_dim = default(context_dim, query_dim)