diff --git a/ldm/models/diffusion/cross_attention_control.py b/ldm/models/diffusion/cross_attention_control.py index 9c8c597869..ff90a24856 100644 --- a/ldm/models/diffusion/cross_attention_control.py +++ b/ldm/models/diffusion/cross_attention_control.py @@ -1,10 +1,13 @@ -from enum import Enum +import enum +from typing import Optional import torch # adapted from bloc97's CrossAttentionControl colab # https://github.com/bloc97/CrossAttentionControl + + class CrossAttentionControl: class Arguments: @@ -27,7 +30,14 @@ class CrossAttentionControl: print('warning: cross-attention control options are not working properly for >1 edit') self.edit_options = non_none_edit_options[0] + class Context: + + class Action(enum.Enum): + NONE = 0 + SAVE = 1, + APPLY = 2 + def __init__(self, arguments: 'CrossAttentionControl.Arguments', step_count: int): """ :param arguments: Arguments for the cross-attention control process @@ -36,14 +46,124 @@ class CrossAttentionControl: self.arguments = arguments self.step_count = step_count + self.self_cross_attention_module_identifiers = [] + self.tokens_cross_attention_module_identifiers = [] + + self.saved_cross_attention_maps = {} + + self.clear_requests(cleanup=True) + + def register_cross_attention_modules(self, model): + for name,module in CrossAttentionControl.get_attention_modules(model, + CrossAttentionControl.CrossAttentionType.SELF): + self.self_cross_attention_module_identifiers.append(name) + for name,module in CrossAttentionControl.get_attention_modules(model, + CrossAttentionControl.CrossAttentionType.TOKENS): + self.tokens_cross_attention_module_identifiers.append(name) + + def request_save_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): + if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: + self.self_cross_attention_action = CrossAttentionControl.Context.Action.SAVE + else: + self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.SAVE + + def request_apply_saved_attention_maps(self, cross_attention_type: 'CrossAttentionControl.CrossAttentionType'): + if cross_attention_type == CrossAttentionControl.CrossAttentionType.SELF: + self.self_cross_attention_action = CrossAttentionControl.Context.Action.APPLY + else: + self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.APPLY + + def is_tokens_cross_attention(self, module_identifier) -> bool: + return module_identifier in self.tokens_cross_attention_module_identifiers + + def get_should_save_maps(self, module_identifier: str) -> bool: + if module_identifier in self.self_cross_attention_module_identifiers: + return self.self_cross_attention_action == CrossAttentionControl.Context.Action.SAVE + elif module_identifier in self.tokens_cross_attention_module_identifiers: + return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.SAVE + return False + + def get_should_apply_saved_maps(self, module_identifier: str) -> bool: + if module_identifier in self.self_cross_attention_module_identifiers: + return self.self_cross_attention_action == CrossAttentionControl.Context.Action.APPLY + elif module_identifier in self.tokens_cross_attention_module_identifiers: + return self.tokens_cross_attention_action == CrossAttentionControl.Context.Action.APPLY + return False + + def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ + -> list['CrossAttentionControl.CrossAttentionType']: + """ + Should cross-attention control be applied on the given step? + :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. + :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. + """ + if percent_through is None: + return [CrossAttentionControl.CrossAttentionType.SELF, CrossAttentionControl.CrossAttentionType.TOKENS] + + opts = self.arguments.edit_options + to_control = [] + if opts['s_start'] <= percent_through and percent_through < opts['s_end']: + to_control.append(CrossAttentionControl.CrossAttentionType.SELF) + if opts['t_start'] <= percent_through and percent_through < opts['t_end']: + to_control.append(CrossAttentionControl.CrossAttentionType.TOKENS) + return to_control + + def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, + slice_size: Optional[int]): + if identifier not in self.saved_cross_attention_maps: + self.saved_cross_attention_maps[identifier] = { + 'dim': dim, + 'slice_size': slice_size, + 'slices': {offset or 0: slice} + } + else: + self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice + + def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): + saved_attention_dict = self.saved_cross_attention_maps[identifier] + if requested_dim is None: + if saved_attention_dict['dim'] is not None: + raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") + return saved_attention_dict['slices'][0] + + if saved_attention_dict['dim'] == requested_dim: + if slice_size != saved_attention_dict['slice_size']: + raise RuntimeError( + f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") + return saved_attention_dict['slices'][requested_offset] + + if saved_attention_dict['dim'] == None: + whole_saved_attention = saved_attention_dict['slices'][0] + if requested_dim == 0: + return whole_saved_attention[requested_offset:requested_offset + slice_size] + elif requested_dim == 1: + return whole_saved_attention[:, requested_offset:requested_offset + slice_size] + + raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") + + def get_slicing_strategy(self, identifier: str) -> Optional[tuple[int, int]]: + saved_attention = self.saved_cross_attention_maps.get(identifier, None) + if saved_attention is None: + return None, None + return saved_attention['dim'], saved_attention['slice_size'] + + def clear_requests(self, cleanup=True): + self.tokens_cross_attention_action = CrossAttentionControl.Context.Action.NONE + self.self_cross_attention_action = CrossAttentionControl.Context.Action.NONE + if cleanup: + self.saved_cross_attention_maps = {} + + def offload_saved_attention_slices_to_cpu(self): + for key, map_dict in self.saved_cross_attention_maps.items(): + for offset, slice in map_dict['slices'].items(): + map_dict[offset] = slice.to('cpu') + @classmethod def remove_cross_attention_control(cls, model): cls.remove_attention_function(model) @classmethod - def setup_cross_attention_control(cls, model, - cross_attention_control_args: Arguments - ): + def setup_cross_attention_control(cls, model, context: Context): """ Inject attention parameters and functions into the passed in model to enable cross attention editing. @@ -53,7 +173,7 @@ class CrossAttentionControl: """ # adapted from init_attention_edit - device = cross_attention_control_args.edited_conditioning.device + device = context.arguments.edited_conditioning.device # urgh. should this be hardcoded? max_length = 77 @@ -61,141 +181,82 @@ class CrossAttentionControl: mask = torch.zeros(max_length) indices_target = torch.arange(max_length, dtype=torch.long) indices = torch.zeros(max_length, dtype=torch.long) - for name, a0, a1, b0, b1 in cross_attention_control_args.edit_opcodes: + for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: if b0 < max_length: if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): # these tokens have not been edited indices[b0:b1] = indices_target[a0:a1] mask[b0:b1] = 1 - cls.inject_attention_function(model) - - for m in cls.get_attention_modules(model, cls.CrossAttentionType.SELF): - m.last_attn_slice_mask = None - m.last_attn_slice_indices = None - - for m in cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS): - m.last_attn_slice_mask = mask.to(device) - m.last_attn_slice_indices = indices.to(device) + context.register_cross_attention_modules(model) + context.cross_attention_mask = mask.to(device) + context.cross_attention_index_map = indices.to(device) + cls.inject_attention_function(model, context) - class CrossAttentionType(Enum): + class CrossAttentionType(enum.Enum): SELF = 1 TOKENS = 2 - @classmethod - def get_active_cross_attention_control_types_for_step(cls, context: 'CrossAttentionControl.Context', percent_through:float=None)\ - -> list['CrossAttentionControl.CrossAttentionType']: - """ - Should cross-attention control be applied on the given step? - :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. - :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. - """ - if percent_through is None: - return [cls.CrossAttentionType.SELF, cls.CrossAttentionType.TOKENS] - - opts = context.arguments.edit_options - to_control = [] - if opts['s_start'] <= percent_through and percent_through < opts['s_end']: - to_control.append(cls.CrossAttentionType.SELF) - if opts['t_start'] <= percent_through and percent_through < opts['t_end']: - to_control.append(cls.CrossAttentionType.TOKENS) - return to_control - - @classmethod def get_attention_modules(cls, model, which: CrossAttentionType): which_attn = "attn1" if which is cls.CrossAttentionType.SELF else "attn2" - return [module for name, module in model.named_modules() if + return [(name,module) for name, module in model.named_modules() if type(module).__name__ == "CrossAttention" and which_attn in name] - @classmethod - def clear_requests(cls, model, clear_attn_slice=True): - self_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.SELF) - tokens_attention_modules = cls.get_attention_modules(model, cls.CrossAttentionType.TOKENS) - for m in self_attention_modules+tokens_attention_modules: - m.save_last_attn_slice = False - m.use_last_attn_slice = False - if clear_attn_slice: - m.last_attn_slice = None @classmethod - def request_save_attention_maps(cls, model, cross_attention_type: CrossAttentionType): - modules = cls.get_attention_modules(model, cross_attention_type) - for m in modules: - # clear out the saved slice in case the outermost dim changes - m.last_attn_slice = None - m.save_last_attn_slice = True - - @classmethod - def request_apply_saved_attention_maps(cls, model, cross_attention_type: CrossAttentionType): - modules = cls.get_attention_modules(model, cross_attention_type) - for m in modules: - m.use_last_attn_slice = True - - - - @classmethod - def inject_attention_function(cls, unet): + def inject_attention_function(cls, unet, context: 'CrossAttentionControl.Context'): # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276 - def attention_slice_wrangler(self, attention_scores, suggested_attention_slice, dim, offset, slice_size): + def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size): - #print("in wrangler with suggested_attention_slice shape", suggested_attention_slice.shape, "dim", dim) + #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement() - attn_slice = suggested_attention_slice - if dim is not None: - start = offset - end = start+slice_size - #print(f"in wrangler, sliced dim {dim} {start}-{end}, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") - #else: - # print(f"in wrangler, whole, use_last_attn_slice is {self.use_last_attn_slice}, save_last_attn_slice is {self.save_last_attn_slice}") + attention_slice = suggested_attention_slice - if self.use_last_attn_slice: - if dim is None: - last_attn_slice = self.last_attn_slice - # print("took whole slice of shape", attn_slice.shape, "from complete shape", self.last_attn_slice.shape) + if context.get_should_save_maps(module.identifier): + #print(module.identifier, "saving suggested_attention_slice of shape", + # suggested_attention_slice.shape, "dim", dim, "offset", offset) + slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice + context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size) + elif context.get_should_apply_saved_maps(module.identifier): + #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset) + saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size) + + # slice may have been offloaded to CPU + saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device) + + if context.is_tokens_cross_attention(module.identifier): + index_map = context.cross_attention_index_map + remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map) + this_attention_slice = suggested_attention_slice + + mask = context.cross_attention_mask + saved_mask = mask + this_mask = 1 - mask + attention_slice = remapped_saved_attention_slice * saved_mask + \ + this_attention_slice * this_mask else: - last_attn_slice = self.last_attn_slice[offset] - - if self.last_attn_slice_mask is None: # just use everything - attn_slice = last_attn_slice - else: - last_attn_slice_mask = self.last_attn_slice_mask - remapped_last_attn_slice = torch.index_select(last_attn_slice, -1, self.last_attn_slice_indices) + attention_slice = saved_attention_slice - this_attn_slice = attn_slice - this_attn_slice_mask = 1 - last_attn_slice_mask - attn_slice = this_attn_slice * this_attn_slice_mask + \ - remapped_last_attn_slice * last_attn_slice_mask - - if self.save_last_attn_slice: - if dim is None: - self.last_attn_slice = attn_slice - else: - if self.last_attn_slice is None: - self.last_attn_slice = { offset: attn_slice } - else: - self.last_attn_slice[offset] = attn_slice - - return attn_slice + return attention_slice for name, module in unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention": - module.last_attn_slice = None - module.last_attn_slice_indices = None - module.last_attn_slice_mask = None - module.use_last_attn_weights = False - module.use_last_attn_slice = False - module.save_last_attn_slice = False + module.identifier = name module.set_attention_slice_wrangler(attention_slice_wrangler) + module.set_slicing_strategy_getter(lambda module, module_identifier=name: \ + context.get_slicing_strategy(module_identifier)) @classmethod def remove_attention_function(cls, unet): + # clear wrangler callback for name, module in unet.named_modules(): module_name = type(module).__name__ if module_name == "CrossAttention": module.set_attention_slice_wrangler(None) + module.set_slicing_strategy_getter(None) diff --git a/ldm/models/diffusion/shared_invokeai_diffusion.py b/ldm/models/diffusion/shared_invokeai_diffusion.py index 5a9cc3eb74..1b181ba388 100644 --- a/ldm/models/diffusion/shared_invokeai_diffusion.py +++ b/ldm/models/diffusion/shared_invokeai_diffusion.py @@ -1,9 +1,11 @@ +import traceback from math import ceil from typing import Callable, Optional, Union import torch from ldm.models.diffusion.cross_attention_control import CrossAttentionControl +from ldm.modules.attention import get_mem_free_total class InvokeAIDiffuserComponent: @@ -34,7 +36,7 @@ class InvokeAIDiffuserComponent: """ self.model = model self.model_forward_callback = model_forward_callback - + self.cross_attention_control_context = None def setup_cross_attention_control(self, conditioning: ExtraConditioningInfo, step_count: int): self.conditioning = conditioning @@ -42,11 +44,7 @@ class InvokeAIDiffuserComponent: arguments=self.conditioning.cross_attention_control_args, step_count=step_count ) - CrossAttentionControl.setup_cross_attention_control(self.model, - cross_attention_control_args=self.conditioning.cross_attention_control_args - ) - #todo: refactor edited_conditioning, edit_opcodes, edit_options into a struct - #todo: apply edit_options using step_count + CrossAttentionControl.setup_cross_attention_control(self.model, self.cross_attention_control_context) def remove_cross_attention_control(self): self.conditioning = None @@ -54,6 +52,7 @@ class InvokeAIDiffuserComponent: CrossAttentionControl.remove_cross_attention_control(self.model) + def do_diffusion_step(self, x: torch.Tensor, sigma: torch.Tensor, unconditioning: Union[torch.Tensor,dict], conditioning: Union[torch.Tensor,dict], @@ -70,12 +69,12 @@ class InvokeAIDiffuserComponent: :return: the new latents after applying the model to x using unscaled unconditioning and CFG-scaled conditioning. """ - CrossAttentionControl.clear_requests(self.model) cross_attention_control_types_to_do = [] + context: CrossAttentionControl.Context = self.cross_attention_control_context if self.cross_attention_control_context is not None: percent_through = self.estimate_percent_through(step_index, sigma) - cross_attention_control_types_to_do = CrossAttentionControl.get_active_cross_attention_control_types_for_step(self.cross_attention_control_context, percent_through) + cross_attention_control_types_to_do = context.get_active_cross_attention_control_types_for_step(percent_through) wants_cross_attention_control = (len(cross_attention_control_types_to_do) > 0) wants_hybrid_conditioning = isinstance(conditioning, dict) @@ -124,7 +123,7 @@ class InvokeAIDiffuserComponent: return unconditioned_next_x, conditioned_next_x - def apply_cross_attention_controlled_conditioning(self, x, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): + def apply_cross_attention_controlled_conditioning(self, x:torch.Tensor, sigma, unconditioning, conditioning, cross_attention_control_types_to_do): # print('pct', percent_through, ': doing cross attention control on', cross_attention_control_types_to_do) # slower non-batched path (20% slower on mac MPS) # We are only interested in using attention maps for conditioned_next_x, but batching them with generation of @@ -134,31 +133,29 @@ class InvokeAIDiffuserComponent: # representing batched uncond + cond, but then when it comes to applying the saved attention, the # wrangler gets an attention tensor which only has shape[0]=8, representing just self.edited_conditionings.) # todo: give CrossAttentionControl's `wrangler` function more info so it can work with a batched call as well. + context:CrossAttentionControl.Context = self.cross_attention_control_context try: unconditioned_next_x = self.model_forward_callback(x, sigma, unconditioning) # process x using the original prompt, saving the attention maps - for type in cross_attention_control_types_to_do: - CrossAttentionControl.request_save_attention_maps(self.model, type) + #print("saving attention maps for", cross_attention_control_types_to_do) + for ca_type in cross_attention_control_types_to_do: + context.request_save_attention_maps(ca_type) _ = self.model_forward_callback(x, sigma, conditioning) - CrossAttentionControl.clear_requests(self.model, clear_attn_slice=False) + context.clear_requests(cleanup=False) # process x again, using the saved attention maps to control where self.edited_conditioning will be applied - for type in cross_attention_control_types_to_do: - CrossAttentionControl.request_apply_saved_attention_maps(self.model, type) + #print("applying saved attention maps for", cross_attention_control_types_to_do) + for ca_type in cross_attention_control_types_to_do: + context.request_apply_saved_attention_maps(ca_type) edited_conditioning = self.conditioning.cross_attention_control_args.edited_conditioning conditioned_next_x = self.model_forward_callback(x, sigma, edited_conditioning) - CrossAttentionControl.clear_requests(self.model) + finally: + context.clear_requests(cleanup=True) - return unconditioned_next_x, conditioned_next_x - - except RuntimeError: - # make sure we clean out the attention slices we're storing on the model - # TODO don't store things on the model - CrossAttentionControl.clear_requests(self.model) - raise + return unconditioned_next_x, conditioned_next_x def estimate_percent_through(self, step_index, sigma): if step_index is not None and self.cross_attention_control_context is not None: diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 4c36fa8a6c..05f6183029 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,6 +1,6 @@ from inspect import isfunction import math -from typing import Callable +from typing import Callable, Optional import torch import torch.nn.functional as F @@ -151,6 +151,17 @@ class SpatialSelfAttention(nn.Module): return x+h_ +def get_mem_free_total(device): + #only on cuda + if not torch.cuda.is_available(): + return None + stats = torch.cuda.memory_stats(device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(device) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + return mem_free_total class CrossAttention(nn.Module): @@ -173,31 +184,43 @@ class CrossAttention(nn.Module): self.mem_total_gb = psutil.virtual_memory().total // (1 << 30) + self.cached_mem_free_total = None self.attention_slice_wrangler = None + self.slicing_strategy_getter = None - def set_attention_slice_wrangler(self, wrangler:Callable[[nn.Module, torch.Tensor, torch.Tensor, int, int, int], torch.Tensor]): + def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]): ''' Set custom attention calculator to be called when attention is calculated - :param wrangler: Callback, with args (self, attention_scores, suggested_attention_slice, dim, offset, slice_size), + :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size), which returns either the suggested_attention_slice or an adjusted equivalent. - self is the current CrossAttention module for which the callback is being invoked. - attention_scores are the scores for attention - suggested_attention_slice is a softmax(dim=-1) over attention_scores - dim is -1 if the call is non-sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. - If dim is >= 0, offset and slice_size specify the slice start and length. + `module` is the current CrossAttention module for which the callback is being invoked. + `suggested_attention_slice` is the default-calculated attention slice + `dim` is -1 if the attenion map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing. + If `dim` is >= 0, `offset` and `slice_size` specify the slice start and length. Pass None to use the default attention calculation. :return: ''' self.attention_slice_wrangler = wrangler + def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]): + self.slicing_strategy_getter = getter + + def cache_free_memory_count(self, device): + self.cached_mem_free_total = get_mem_free_total(device) + print("free cuda memory: ", self.cached_mem_free_total) + + def clear_cached_free_memory_count(self): + self.cached_mem_free_total = None + def einsum_lowest_level(self, q, k, v, dim, offset, slice_size): # calculate attention scores attention_scores = einsum('b i d, b j d -> b i j', q, k) - # calculate attenion slice by taking the best scores for each latent pixel + # calculate attention slice by taking the best scores for each latent pixel default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype) - if self.attention_slice_wrangler is not None: - attention_slice = self.attention_slice_wrangler(self, attention_scores, default_attention_slice, dim, offset, slice_size) + attention_slice_wrangler = self.attention_slice_wrangler + if attention_slice_wrangler is not None: + attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size) else: attention_slice = default_attention_slice @@ -240,17 +263,27 @@ class CrossAttention(nn.Module): return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1)) def einsum_op_cuda(self, q, k, v): - stats = torch.cuda.memory_stats(q.device) - mem_active = stats['active_bytes.all.current'] - mem_reserved = stats['reserved_bytes.all.current'] - mem_free_cuda, _ = torch.cuda.mem_get_info(q.device) - mem_free_torch = mem_reserved - mem_active - mem_free_total = mem_free_cuda + mem_free_torch + # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation) + slicing_strategy_getter = self.slicing_strategy_getter + if slicing_strategy_getter is not None: + (dim, slice_size) = slicing_strategy_getter(self) + if dim is not None: + # print("using saved slicing strategy with dim", dim, "slice size", slice_size) + if dim == 0: + return self.einsum_op_slice_dim0(q, k, v, slice_size) + elif dim == 1: + return self.einsum_op_slice_dim1(q, k, v, slice_size) + + # fallback for when there is no saved strategy, or saved strategy does not slice + mem_free_total = self.cached_mem_free_total or get_mem_free_total(q.device) # Divide factor of safety as there's copying and fragmentation return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def get_attention_mem_efficient(self, q, k, v): if q.device.type == 'cuda': + torch.cuda.empty_cache() + #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device)) return self.einsum_op_cuda(q, k, v) if q.device.type == 'mps':