mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix flake8 error
This commit is contained in:
parent
a536719fc3
commit
11b670755d
@ -4,6 +4,7 @@ Utility routine used for autodetection of optimal slice size
|
||||
for attention mechanism.
|
||||
"""
|
||||
import torch
|
||||
import psutil
|
||||
|
||||
|
||||
def auto_detect_slice_size(latents: torch.Tensor) -> str:
|
||||
@ -16,6 +17,13 @@ def auto_detect_slice_size(latents: torch.Tensor) -> str:
|
||||
* latents.size(dim=3)
|
||||
* bytes_per_element_needed_for_baddbmm_duplication
|
||||
)
|
||||
if latents.device.type in {"cpu", "mps"}:
|
||||
mem_free = psutil.virtual_memory().free
|
||||
elif latents.device.type == "cuda":
|
||||
mem_free, _ = torch.cuda.mem_get_info(latents.device)
|
||||
else:
|
||||
raise ValueError(f"unrecognized device {latents.device}")
|
||||
|
||||
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
|
||||
return "max"
|
||||
elif torch.backends.mps.is_available():
|
||||
|
Loading…
Reference in New Issue
Block a user