Refactor attention.CrossAttention to remove duplicate code and apply optimizations

Apply ~6% speedup by moving * self.scale to earlier on a smaller tensor.
When we have enough VRAM don't make a useless zeros tensor.
Switch between cuda/mps/cpu based on q.device.type to allow cleaner per architecture future optimizations.
For cuda and cpu keep VRAM usage and faster slicing consistent.
For cpu use smaller slices. Tested ~20% faster on i7, 9.8 to 7.7 s/it.
Fix = typo to self.mem_total >= 8 in einsum_op_mps_v2 as per #582 discussion.
This commit is contained in:
Mihail Dumitrescu 2022-09-14 20:25:56 +03:00 committed by Mihai
parent 100f2e8f57
commit e0951f28cf

View File

@ -90,7 +90,7 @@ class LinearAttention(nn.Module):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@ -167,101 +167,85 @@ class CrossAttention(nn.Module):
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
if torch.cuda.is_available():
self.einsum_op = self.einsum_op_cuda
else:
self.mem_total = psutil.virtual_memory().total / (1024**3)
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
def einsum_op_compvis(self, q, k, v, r1):
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1
r1 = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
def einsum_op_mps_v1(self, q, k, v, r1):
def einsum_op_compvis(self, q, k, v):
s = einsum('b i d, b j d -> b i j', q, k)
s = s.softmax(dim=-1, dtype=s.dtype)
return einsum('b i j, b j d -> b i d', s, v)
def einsum_op_slice_0(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[0], slice_size):
end = i + slice_size
r[i:end] = self.einsum_op_compvis(q[i:end], k[i:end], v[i:end])
return r
def einsum_op_slice_1(self, q, k, v, slice_size):
r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
r[:, i:end] = self.einsum_op_compvis(q[:, i:end], k, v)
return r
def einsum_op_mps_v1(self, q, k, v):
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
r1 = self.einsum_op_compvis(q, k, v, r1)
return self.einsum_op_compvis(q, k, v)
else:
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
return self.einsum_op_slice_1(q, k, v, slice_size)
def einsum_op_mps_v2(self, q, k, v, r1):
if self.mem_total >= 8 and q.shape[1] <= 4096:
r1 = self.einsum_op_compvis(q, k, v, r1)
def einsum_op_mps_v2(self, q, k, v):
if self.mem_total_gb > 8 and q.shape[1] <= 4096:
return self.einsum_op_compvis(q, k, v)
else:
slice_size = 1
for i in range(0, q.shape[0], slice_size):
end = min(q.shape[0], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
s1 *= self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
del s2
return r1
def einsum_op_cuda(self, q, k, v, r1):
return self.einsum_op_slice_0(q, k, v, 1)
def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
if size_mb <= max_tensor_mb:
return self.einsum_op_compvis(q, k, v)
div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
if div <= q.shape[0]:
return self.einsum_op_slice_0(q, k, v, q.shape[0] // div)
return self.einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
def einsum_op_cuda(self, q, k, v):
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch
# Divide factor of safety as there's copying and fragmentation
return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4
mem_required = tensor_size * 2.5
steps = 1
def einsum_op(self, q, k, v):
if q.device.type == 'cuda':
return self.einsum_op_cuda(q, k, v)
if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
if q.device.type == 'mps':
if self.mem_total_gb >= 32:
return self.einsum_op_mps_v1(q, k, v)
return self.einsum_op_mps_v2(q, k, v)
if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = min(q.shape[1], i + slice_size)
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
del s1
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
return r1
# Smaller slices are faster due to L2/L3/SLC caches.
# Tested on i7 with 8MB L3 cache.
return self.einsum_op_tensor_mem(q, k, v, 32)
def forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
q = self.to_q(x)
context = default(context, x)
k_in = self.to_k(context)
v_in = self.to_v(context)
device_type = 'mps' if x.device.type == 'mps' else 'cuda'
k = self.to_k(context) * self.scale
v = self.to_v(context)
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
r1 = self.einsum_op(q, k, v, r1)
del q, k, v
r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1
return self.to_out(r2)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
r = self.einsum_op(q, k, v)
return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
class BasicTransformerBlock(nn.Module):