mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update model.py
This commit is contained in:
parent
9cdf3aca7d
commit
25d9ccc509
@ -210,10 +210,7 @@ class AttnBlock(nn.Module):
|
||||
h_ = torch.zeros_like(k, device=q.device)
|
||||
|
||||
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
|
||||
|
||||
if device_type == 'mps':
|
||||
mem_free_total = psutil.virtual_memory().available
|
||||
else:
|
||||
if device_type == 'cuda':
|
||||
stats = torch.cuda.memory_stats(q.device)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -221,14 +218,21 @@ class AttnBlock(nn.Module):
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = mem_free_cuda + mem_free_torch
|
||||
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
|
||||
mem_required = tensor_size * 2.5
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
else:
|
||||
if psutil.virtual_memory().available / (1024**3) < 12:
|
||||
slice_size = 1
|
||||
else:
|
||||
slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))
|
||||
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user