mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
commit
0d1aad53ef
@ -167,30 +167,25 @@ class CrossAttention(nn.Module):
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
mem_av = psutil.virtual_memory().available / (1024**3)
|
||||
if mem_av > 32:
|
||||
self.einsum_op = self.einsum_op_v1
|
||||
elif mem_av > 12:
|
||||
self.einsum_op = self.einsum_op_v2
|
||||
else:
|
||||
self.einsum_op = self.einsum_op_v3
|
||||
del mem_av
|
||||
|
||||
if torch.cuda.is_available():
|
||||
self.einsum_op = self.einsum_op_cuda
|
||||
else:
|
||||
self.einsum_op = self.einsum_op_v4
|
||||
self.mem_total = psutil.virtual_memory().total / (1024**3)
|
||||
self.einsum_op = self.einsum_op_mps_v1 if self.mem_total >= 32 else self.einsum_op_mps_v2
|
||||
|
||||
# mps 64-128 GB
|
||||
def einsum_op_v1(self, q, k, v, r1):
|
||||
if q.shape[1] <= 4096: # for 512x512: the max q.shape[1] is 4096
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # aggressive/faster: operation in one go
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
def einsum_op_compvis(self, q, k, v, r1):
|
||||
s1 = einsum('b i d, b j d -> b i j', q, k) * self.scale # faster
|
||||
s2 = s1.softmax(dim=-1, dtype=q.dtype)
|
||||
del s1
|
||||
r1 = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
return r1
|
||||
|
||||
def einsum_op_mps_v1(self, q, k, v, r1):
|
||||
if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
# q.shape[0] * q.shape[1] * slice_size >= 2**31 throws err
|
||||
# needs around half of that slice_size to not generate noise
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
@ -201,33 +196,22 @@ class CrossAttention(nn.Module):
|
||||
del s2
|
||||
return r1
|
||||
|
||||
# mps 16-32 GB (can be optimized)
|
||||
def einsum_op_v2(self, q, k, v, r1):
|
||||
slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
|
||||
for i in range(0, q.shape[1], slice_size): # conservative/less mem: operation in steps
|
||||
end = i + slice_size
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
def einsum_op_mps_v2(self, q, k, v, r1):
|
||||
if self.mem_total >= 8 and q.shape[1] <= 4096:
|
||||
r1 = self.einsum_op_compvis(q, k, v, r1)
|
||||
else:
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size):
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
|
||||
del s2
|
||||
return r1
|
||||
|
||||
# mps 8 GB
|
||||
def einsum_op_v3(self, q, k, v, r1):
|
||||
slice_size = 1
|
||||
for i in range(0, q.shape[0], slice_size): # iterate over q.shape[0]
|
||||
end = min(q.shape[0], i + slice_size)
|
||||
s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) # adapted einsum for mem
|
||||
s1 *= self.scale
|
||||
s2 = s1.softmax(dim=-1, dtype=r1.dtype)
|
||||
del s1
|
||||
r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) # adapted einsum for mem
|
||||
del s2
|
||||
return r1
|
||||
|
||||
# cuda
|
||||
def einsum_op_v4(self, q, k, v, r1):
|
||||
|
||||
def einsum_op_cuda(self, q, k, v, r1):
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
|
Loading…
Reference in New Issue
Block a user