From d0a71dc3615036b2041fbeb31107f2c8364a8d36 Mon Sep 17 00:00:00 2001
From: Any-Winter-4079 <50542132+Any-Winter-4079@users.noreply.github.com>
Date: Tue, 13 Sep 2022 16:53:45 +0200
Subject: [PATCH] Update attention.py for 16-32GB M1 performance (#540)

Code cleanup and attention.py einsum_ops update for M1 16-32GB performance.
Expected: On par with fastest ever from 8 to 128GB for 512x512. Allows large images.
---
 ldm/modules/attention.py | 78 ++++++++++++++++------------------------
 1 file changed, 31 insertions(+), 47 deletions(-)

diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py
index 894c4db839..55e5b9f8a4 100644
--- a/ldm/modules/attention.py
+++ b/ldm/modules/attention.py
@@ -167,30 +167,25 @@ class CrossAttention(nn.Module):
             nn.Linear(inner_dim, query_dim),
             nn.Dropout(dropout)
         )
-
-        if not torch.cuda.is_available():
-            mem_av = psutil.virtual_memory().available / (1024**3)
-            if mem_av > 32:
-                self.einsum_op = self.einsum_op_v1
-            elif mem_av > 12:
-                self.einsum_op = self.einsum_op_v2
-            else:
-                self.einsum_op = self.einsum_op_v3
-            del mem_av
+
+        if torch.cuda.is_available():
+            self.einsum_op = self.einsum_op_cuda
         else:
-            self.einsum_op = self.einsum_op_v4
+            self.mem_total = psutil.virtual_memory().total / (1024**3)
+            self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
 
-    # mps 64-128 GB
-    def einsum_op_v1(self, q, k, v, r1):
-        if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
-            s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
-            s2 = s1.softmax(dim=-1, dtype=q.dtype)
-            del s1
-            r1 = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+    def einsum_op_compvis(self, q, k, v, r1):
+        s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
+        s2 = s1.softmax(dim=-1, dtype=q.dtype)
+        del s1
+        r1 = einsum('b i j, b j d -> b i d', s2, v)
+        del s2
+        return r1
+
+    def einsum_op_mps_v1(self, q, k, v, r1):
+        if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
+            r1 = self.einsum_op_compvis(q, k, v, r1)
         else:
-            # q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
-            # needs around half of that slice_size to not generate noise
             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
             for i in range(0, q.shape[1], slice_size):
                 end = i + slice_size
@@ -201,33 +196,22 @@ class CrossAttention(nn.Module):
                 del s2
         return r1
 
-    # mps 16-32 GB (can be optimized)
-    def einsum_op_v2(self, q, k, v, r1):
-        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
-        for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
-            end = i + slice_size
-            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
-            del s2
+    def einsum_op_mps_v2(self, q, k, v, r1):
+        if self.mem_total >= 8 and q.shape[1] <= 4096:
+            r1 = self.einsum_op_compvis(q, k, v, r1)
+        else:
+            slice_size = 1
+            for i in range(0, q.shape[0], slice_size):
+                end = min(q.shape[0], i + slice_size)
+                s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+                s1 *= self.scale
+                s2 = s1.softmax(dim=-1, dtype=r1.dtype)
+                del s1
+                r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
+                del s2
         return r1
-
-    # mps 8 GB
-    def einsum_op_v3(self, q, k, v, r1):
-        slice_size = 1
-        for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
-            end = min(q.shape[0], i + slice_size)
-            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
-            s1 *= self.scale
-            s2 = s1.softmax(dim=-1, dtype=r1.dtype)
-            del s1
-            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
-            del s2
-        return r1
-
-    # cuda
-    def einsum_op_v4(self, q, k, v, r1):
+
+    def einsum_op_cuda(self, q, k, v, r1):
         stats = torch.cuda.memory_stats(q.device)
         mem_active = stats['active_bytes.all.current']
         mem_reserved = stats['reserved_bytes.all.current']
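
Note (illustration, not part of the patch): the sketch below restates, as a standalone
function, the slicing strategy behind the new einsum_op_mps_v1 path. When the query
length fits (q.shape[1] <= 4096, the 512x512 case), attention runs in one CompVis-style
pass; otherwise the query-token axis is chunked so that q.shape[0] * q.shape[1] *
slice_size stays near 2**30, which the removed comments describe as roughly half the
point where MPS throws an error. The name sliced_attention, the shapes, and the
max(1, ...) guard are assumptions for illustration and do not appear in
ldm/modules/attention.py.

    import math
    import torch
    from torch import einsum

    def sliced_attention(q, k, v, scale):
        # hypothetical standalone helper; q, k, v: (batch*heads, tokens, dim_head)
        if q.shape[1] <= 4096:
            # small enough: compute the full similarity matrix in one go
            s = einsum('b i d, b j d -> b i j', q, k) * scale
            return einsum('b i j, b j d -> b i d', s.softmax(dim=-1, dtype=q.dtype), v)
        # otherwise process the query-token axis in slices; each slice only
        # materializes a (b, slice_size, j) similarity matrix
        r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        slice_size = max(1, math.floor(2**30 / (q.shape[0] * q.shape[1])))  # guard against 0
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size
            s = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
            r[:, i:end] = einsum('b i j, b j d -> b i d', s.softmax(dim=-1, dtype=q.dtype), v)
        return r

For comparison, the patch's other paths use the same idea with different slice choices:
einsum_op_mps_v2 (MPS under 32 GB) falls back to slicing one batch/head at a time, and
einsum_op_cuda derives its slice size from torch.cuda.memory_stats.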