Merge branch 'development' into development

This commit is contained in:
Lincoln Stein 2022-09-13 10:57:05 -04:00 committed by GitHub
commit 0d1aad53ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -167,30 +167,25 @@ class CrossAttention(nn.Module):
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
if not torch.cuda.is_available():
mem_av = psutil.virtual_memory().available / (1024**3)
if mem_av > 32:
self.einsum_op = self.einsum_op_v1
elif mem_av > 12:
self.einsum_op = self.einsum_op_v2
else:
self.einsum_op = self.einsum_op_v3
del mem_av
if torch.cuda.is_available():
self.einsum_op = self.einsum_op_cuda
else:
self.einsum_op = self.einsum_op_v4
self.mem_total = psutil.virtual_memory().total / (1024**3)
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
# mps 64-128 GB
def einsum_op_v1(self, q, k, v, r1):
if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
def einsum_op_compvis(self, q, k, v, r1):
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
def einsum_op_mps_v1(self, q, k, v, r1):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
r1 = self.einsum_op_compvis(q, k, v, r1)
else:
# q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
# needs around half of that slice_size to not generate noise
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
@ -201,33 +196,22 @@ class CrossAttention(nn.Module):
del s2
return r1
# mps 16-32 GB (can be optimized)
def einsum_op_v2(self, q, k, v, r1):
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
def einsum_op_mps_v2(self, q, k, v, r1):
if self.mem_total >= 8 and q.shape[1] <= 4096:
r1 = self.einsum_op_compvis(q, k, v, r1)
else:
slice_size = 1
for i in range(0, q.shape[0], slice_size):
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
return r1
# mps 8 GB
def einsum_op_v3(self, q, k, v, r1):
slice_size = 1
for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
del s2
return r1
# cuda
def einsum_op_v4(self, q, k, v, r1):
def einsum_op_cuda(self, q, k, v, r1):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']