Add monkeypatch for xformers to align unaligned attention_mask
parent ef3bf2803f
commit b65c9ad612
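Background: xformers' memory-efficient attention kernels expect an attention-bias tensor with aligned storage (roughly, `attn_bias.stride(-2)` must be a multiple of 8 for the CUTLASS kernel). diffusers can hand it a mask whose key length is not a multiple of 8 (e.g. 77 CLIP tokens), so the patch below copies the mask into a padded buffer and slices it back to the original logical shape.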
@@ -761,3 +761,47 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
diffusers.ControlNetModel = ControlNetModel
diffusers.models.controlnet.ControlNetModel = ControlNetModel

try:
    import xformers

    xformers_available = True
except ImportError:
    xformers_available = False


if xformers_available:
    # TODO: remove when fixed in diffusers
    _xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention

    def new_memory_efficient_attention(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias=None,
        p: float = 0.0,
        scale: Optional[float] = None,
        *,
        op=None,
    ):
        # diffusers does not pad the attention mask to a multiple of 8,
        # which xformers requires of the mask's storage.
        if attn_bias is not None and isinstance(attn_bias, torch.Tensor):
            orig_size = attn_bias.shape[-1]
            # round the last dimension up to the next multiple of 8
            new_size = ((orig_size + 7) // 8) * 8
            aligned_attn_bias = torch.zeros(
                (attn_bias.shape[0], attn_bias.shape[1], new_size),
                device=attn_bias.device,
                dtype=attn_bias.dtype,
            )
            aligned_attn_bias[:, :, :orig_size] = attn_bias
            # slice back to the logical size; the view keeps the aligned storage
            attn_bias = aligned_attn_bias[:, :, :orig_size]

        return _xformers_memory_efficient_attention(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_bias,
            p=p,
            scale=scale,
            op=op,
        )

    xformers.ops.memory_efficient_attention = new_memory_efficient_attention
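For context, here is a minimal standalone sketch of the pad-then-slice trick the patch relies on. The helper name `_align_last_dim` and the shapes are illustrative, not part of the commit; the final assertion mirrors what the xformers kernel expects (roughly, `attn_bias.stride(-2) % 8 == 0`).

import torch

def _align_last_dim(mask: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    # Round the last dimension up to the next multiple of `multiple`,
    # copy the mask into the padded buffer, then slice back: the returned
    # view has the original logical shape but aligned storage underneath.
    orig = mask.shape[-1]
    padded = ((orig + multiple - 1) // multiple) * multiple
    if padded == orig:
        return mask
    buf = torch.zeros((*mask.shape[:-1], padded), device=mask.device, dtype=mask.dtype)
    buf[..., :orig] = mask
    return buf[..., :orig]

bias = torch.randn(16, 4096, 77)     # kv length 77 is not a multiple of 8
aligned = _align_last_dim(bias)
assert aligned.shape == bias.shape   # logical shape is unchanged
assert aligned.stride(-2) % 8 == 0   # row stride is now aligned for xformers

The key design point is that the final slice does not copy: it returns a view into the padded buffer, so the caller sees the original shape while the underlying row width stays a multiple of 8.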