diff --git a/invokeai/backend/util/hotfixes.py b/invokeai/backend/util/hotfixes.py index 3d7f278f86..cf97d494d7 100644 --- a/invokeai/backend/util/hotfixes.py +++ b/invokeai/backend/util/hotfixes.py @@ -761,3 +761,47 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin): diffusers.ControlNetModel = ControlNetModel diffusers.models.controlnet.ControlNetModel = ControlNetModel + +try: + import xformers + xformers_available = True +except: + xformers_available = False + + +if xformers_available: + # TODO: remove when fixed in diffusers + _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention + def new_memory_efficient_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias = None, + p: float = 0.0, + scale: Optional[float] = None, + *, + op = None, + ): + # diffusers not align shape to 8, which is required by xformers + if attn_bias is not None and type(attn_bias) is torch.Tensor: + orig_size = attn_bias.shape[-1] + new_size = ((orig_size + 7) // 8) * 8 + aligned_attn_bias = torch.zeros( + (attn_bias.shape[0], attn_bias.shape[1], new_size), + device=attn_bias.device, + dtype=attn_bias.dtype, + ) + aligned_attn_bias[:,:,:orig_size] = attn_bias + attn_bias = aligned_attn_bias[:,:,:orig_size] + + return _xformers_memory_efficient_attention( + query=query, + key=key, + value=value, + attn_bias=attn_bias, + p=p, + scale=scale, + op=op, + ) + + xformers.ops.memory_efficient_attention = new_memory_efficient_attention