mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Update model.py
This commit is contained in:
parent
9cdf3aca7d
commit
25d9ccc509
@ -210,10 +210,7 @@ class AttnBlock(nn.Module):
|
|||||||
h_ = torch.zeros_like(k, device=q.device)
|
h_ = torch.zeros_like(k, device=q.device)
|
||||||
|
|
||||||
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
|
device_type = 'mps' if q.device.type == 'mps' else 'cuda'
|
||||||
|
if device_type == 'cuda':
|
||||||
if device_type == 'mps':
|
|
||||||
mem_free_total = psutil.virtual_memory().available
|
|
||||||
else:
|
|
||||||
stats = torch.cuda.memory_stats(q.device)
|
stats = torch.cuda.memory_stats(q.device)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
@ -229,6 +226,13 @@ class AttnBlock(nn.Module):
|
|||||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||||
|
|
||||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||||
|
|
||||||
|
else:
|
||||||
|
if psutil.virtual_memory().available / (1024**3) < 12:
|
||||||
|
slice_size = 1
|
||||||
|
else:
|
||||||
|
slice_size = min(q.shape[1], math.floor(2**30 / (q.shape[0] * q.shape[1])))
|
||||||
|
|
||||||
for i in range(0, q.shape[1], slice_size):
|
for i in range(0, q.shape[1], slice_size):
|
||||||
end = i + slice_size
|
end = i + slice_size
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user