disable neonpixel optimizations on M1 hardware (#414)

* disable neonpixel optimizations on M1 hardware

* fix typo that was causing random noise images on m1
This commit is contained in:
Lincoln Stein 2022-09-07 13:28:11 -04:00 committed by GitHub
parent 7670ecc63f
commit 29ab3c2028
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,9 +1,10 @@
from inspect import isfunction
import math
from inspect import isfunction
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from torch import nn, einsum
from ldm.modules.diffusionmodules.util import checkpoint
@ -13,7 +14,7 @@ def exists(val):
def uniq(arr):
return{el: True for el in arr}.keys()
return {el: True for el in arr}.keys()
def default(val, d):
@ -82,14 +83,14 @@ class LinearAttention(nn.Module):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@ -131,12 +132,12 @@ class SpatialSelfAttention(nn.Module):
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = w_ * (int(c) ** (-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
@ -146,7 +147,7 @@ class SpatialSelfAttention(nn.Module):
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
return x + h_
class CrossAttention(nn.Module):
@ -174,6 +175,7 @@ class CrossAttention(nn.Module):
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
device_type = x.device.type
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
@ -188,9 +190,11 @@ class CrossAttention(nn.Module):
sim.masked_fill_(~mask, max_neg_value)
del mask
# attention, what we cannot get enough of, by halves
sim[4:] = sim[4:].softmax(dim=-1)
sim[:4] = sim[:4].softmax(dim=-1)
if device_type == 'mps': #special case for M1 - disable neonsecret optimization
sim = sim.softmax(dim=-1)
else:
sim[4:] = sim[4:].softmax(dim=-1)
sim[:4] = sim[:4].softmax(dim=-1)
sim = einsum('b i j, b j d -> b i d', sim, v)
sim = rearrange(sim, '(b h) n d -> b n (h d)', h=h)
@ -200,7 +204,8 @@ class CrossAttention(nn.Module):
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
super().__init__()
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head,
dropout=dropout) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
@ -228,6 +233,7 @@ class SpatialTransformer(nn.Module):
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None):
super().__init__()
@ -243,7 +249,7 @@ class SpatialTransformer(nn.Module):
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)]
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,