Update attention.py for 16-32GB M1 performance (#540)

Code cleanup and an update to the einsum ops in attention.py to improve performance on 16-32 GB M1 machines.
Expected: 512x512 generation on par with the fastest results seen across 8-128 GB machines, while still allowing large images.
Any-Winter-4079 2022-09-13 16:53:45 +02:00 committed by GitHub
parent e1a6d0c138
commit d0a71dc361

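Before the diff, a back-of-the-envelope check of the slicing threshold the new MPS ops rely on. The shapes are illustrative assumptions (8 attention heads and an effective batch of 2, so q.shape[0] == 16), not values taken from the change itself:

    # Sketch only: how many query slices the new MPS path would need for a given image size.
    import math

    def mps_slices(batch_heads, seq_len):
        # 512x512 latents give seq_len <= 4096, which fits the one-shot CompVis-style einsum
        if seq_len <= 4096:
            return 1
        # per the removed comment, an attention matrix with >= 2**31 elements errors out on MPS,
        # so the code targets roughly half of that per slice
        slice_size = math.floor(2**30 / (batch_heads * seq_len))
        return math.ceil(seq_len / slice_size)

    print(mps_slices(16, 64 * 64))     # 512x512   -> 1 (no slicing needed)
    print(mps_slices(16, 128 * 128))   # 1024x1024 -> 4 slices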

@@ -167,30 +167,25 @@ class CrossAttention(nn.Module):
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-        if not torch.cuda.is_available():
-            mem_av = psutil.virtual_memory().available / (1024**3)
-            if mem_av > 32:
-                self.einsum_op = self.einsum_op_v1
-            elif mem_av > 12:
-                self.einsum_op = self.einsum_op_v2
-            else:
-                self.einsum_op = self.einsum_op_v3
-            del mem_av
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
         else:
-            self.einsum_op = self.einsum_op_v4
+            self.mem_total = psutil.virtual_memory().total / (1024**3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
 
-    # mps 64-128 GB
-    def einsum_op_v1(self, q, k, v, r1):
-        if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
-            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
-            s2 = s1.softmax(dim=-1, dtype=q.dtype)
-            del s1
-            r1 = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
         else:
-            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
-            # needs around half of that slice_size to not generate noise
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
             for i in range(0, q.shape[1], slice_size):
                 end = i + slice_size
@@ -201,33 +196,22 @@ class CrossAttention(nn.Module):
                 del s2
         return r1
 
-    # mps 16-32 GB (can be optimized)
-    def einsum_op_v2(self, q, k, v, r1):
-        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-        for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
-            end = i + slice_size
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
-        return r1
-
-    # mps 8 GB
-    def einsum_op_v3(self, q, k, v, r1):
-        slice_size = 1
-        for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
-            end = min(q.shape[0], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
-            s1 *= self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
-            del s2
-        return r1
-
-    # cuda
-    def einsum_op_v4(self, q, k, v, r1):
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
+        return r1
+
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
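Both MPS paths rely on attention being computable independently per query row and per batch element, so slicing either dimension is exact rather than an approximation. A minimal standalone check with toy tensors (a sketch, not the repository's code) illustrating both slicing styles:

    import torch
    from torch import einsum

    b, n, d = 2, 64, 8                       # batch*heads, tokens, head dim (toy sizes)
    scale = d ** -0.5
    q, k, v = torch.randn(3, b, n, d).unbind(0)

    # One-shot attention, einsum_op_compvis style
    s = (einsum('b i d, b j d -> b i j', q, k) * scale).softmax(dim=-1)
    full = einsum('b i j, b j d -> b i d', s, v)

    # Sliced over the token dimension, einsum_op_mps_v1 style
    sliced = torch.zeros_like(full)
    for i in range(0, n, 16):
        s1 = einsum('b i d, b j d -> b i j', q[:, i:i+16], k) * scale
        sliced[:, i:i+16] = einsum('b i j, b j d -> b i d', s1.softmax(dim=-1), v)

    # Sliced over the batch dimension, einsum_op_mps_v2 style
    batched = torch.zeros_like(full)
    for i in range(b):
        s1 = einsum('b i d, b j d -> b i j', q[i:i+1], k[i:i+1]) * scale
        batched[i:i+1] = einsum('b i j, b j d -> b i d', s1.softmax(dim=-1), v[i:i+1])

    print(torch.allclose(full, sliced, atol=1e-6), torch.allclose(full, batched, atol=1e-6))  # True True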