From 11b670755d5095b8208eeadeb4d3e5aa36f59f1a Mon Sep 17 00:00:00 2001
From: Lincoln Stein
Date: Sun, 20 Aug 2023 15:39:45 -0400
Subject: [PATCH] fix flake8 error

---
 invokeai/backend/util/attention.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/invokeai/backend/util/attention.py b/invokeai/backend/util/attention.py
index ef80898c2e..a821464394 100644
--- a/invokeai/backend/util/attention.py
+++ b/invokeai/backend/util/attention.py
@@ -4,6 +4,7 @@ Utility routine used for autodetection of optimal slice size for attention
 mechanism.
 """
 import torch
+import psutil


 def auto_detect_slice_size(latents: torch.Tensor) -> str:
@@ -16,6 +17,13 @@ def auto_detect_slice_size(latents: torch.Tensor) -> str:
         * latents.size(dim=3)
         * bytes_per_element_needed_for_baddbmm_duplication
     )
+    if latents.device.type in {"cpu", "mps"}:
+        mem_free = psutil.virtual_memory().free
+    elif latents.device.type == "cuda":
+        mem_free, _ = torch.cuda.mem_get_info(latents.device)
+    else:
+        raise ValueError(f"unrecognized device {latents.device}")
+
     if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
         return "max"
     elif torch.backends.mps.is_available():
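
Side note: below is a minimal standalone sketch of the free-memory probe this patch adds, so the device handling can be tried outside the InvokeAI codebase. The helper name get_free_memory and the __main__ snippet are hypothetical (the patch inlines this logic directly in auto_detect_slice_size); only the psutil.virtual_memory() and torch.cuda.mem_get_info() calls are taken from the diff above.

import psutil
import torch


def get_free_memory(device: torch.device) -> int:
    """Return free memory in bytes for the device the latents live on.

    Hypothetical helper; the patch performs this check inline.
    """
    if device.type in {"cpu", "mps"}:
        # CPU and Apple-silicon MPS tensors live in system RAM, so psutil reports free memory.
        return psutil.virtual_memory().free
    elif device.type == "cuda":
        # torch.cuda.mem_get_info returns (free_bytes, total_bytes) for the given CUDA device.
        free, _total = torch.cuda.mem_get_info(device)
        return free
    raise ValueError(f"unrecognized device {device}")


if __name__ == "__main__":
    # Example: probe free memory for wherever a latents tensor happens to be.
    latents = torch.zeros(1, 4, 64, 64)
    print(get_free_memory(latents.device))

The three-quarters threshold in the patch (mem_free * 3.0 / 4.0) then compares this value against the estimated baddbmm working size to decide whether attention slicing is needed.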