Fix flake8 error

This commit is contained in:
Lincoln Stein 2023-08-20 15:39:45 -04:00
parent a536719fc3
commit 11b670755d

View File

@ -4,6 +4,7 @@ Utility routine used for autodetection of optimal slice size
for attention mechanism.
"""
import torch
import psutil
def auto_detect_slice_size(latents: torch.Tensor) -> str:
@ -16,6 +17,13 @@ def auto_detect_slice_size(latents: torch.Tensor) -> str:
* latents.size(dim=3)
* bytes_per_element_needed_for_baddbmm_duplication
)
if latents.device.type in {"cpu", "mps"}:
mem_free = psutil.virtual_memory().free
elif latents.device.type == "cuda":
mem_free, _ = torch.cuda.mem_get_info(latents.device)
else:
raise ValueError(f"unrecognized device {latents.device}")
if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
return "max"
elif torch.backends.mps.is_available():