This commit is contained in:
Sergey Borisov 2023-08-28 18:36:27 +03:00
parent cd548f73fd
commit 2bf747caf6

View File

@ -764,6 +764,7 @@ diffusers.models.controlnet.ControlNetModel = ControlNetModel
try: try:
import xformers import xformers
xformers_available = True xformers_available = True
except: except:
xformers_available = False xformers_available = False
@ -772,27 +773,28 @@ except:
if xformers_available: if xformers_available:
# TODO: remove when fixed in diffusers # TODO: remove when fixed in diffusers
_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention
def new_memory_efficient_attention( def new_memory_efficient_attention(
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
attn_bias = None, attn_bias=None,
p: float = 0.0, p: float = 0.0,
scale: Optional[float] = None, scale: Optional[float] = None,
*, *,
op = None, op=None,
): ):
# diffusers not align shape to 8, which is required by xformers # diffusers not align shape to 8, which is required by xformers
if attn_bias is not None and type(attn_bias) is torch.Tensor: if attn_bias is not None and type(attn_bias) is torch.Tensor:
orig_size = attn_bias.shape[-1] orig_size = attn_bias.shape[-1]
new_size = ((orig_size + 7) // 8) * 8 new_size = ((orig_size + 7) // 8) * 8
aligned_attn_bias = torch.zeros( aligned_attn_bias = torch.zeros(
(attn_bias.shape[0], attn_bias.shape[1], new_size), (attn_bias.shape[0], attn_bias.shape[1], new_size),
device=attn_bias.device, device=attn_bias.device,
dtype=attn_bias.dtype, dtype=attn_bias.dtype,
) )
aligned_attn_bias[:,:,:orig_size] = attn_bias aligned_attn_bias[:, :, :orig_size] = attn_bias
attn_bias = aligned_attn_bias[:,:,:orig_size] attn_bias = aligned_attn_bias[:, :, :orig_size]
return _xformers_memory_efficient_attention( return _xformers_memory_efficient_attention(
query=query, query=query,