# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Utility routine used for autodetection of optimal slice size
for attention mechanism.
"""
import os
from typing import Optional

import torch


def _free_memory_bytes(device: torch.device) -> Optional[int]:
    """Return the free memory, in bytes, on *device*, or None if unknown.

    CUDA devices are queried through ``torch.cuda.mem_get_info``. For "cpu"
    and "mps" devices the host's available physical memory is read through
    ``os.sysconf``; platforms that lack ``os.sysconf`` or the required
    sysconf names yield None, as does any other device type.
    """
    if device.type == "cuda":
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        return free_bytes

    if device.type in {"cpu", "mps"}:
        # os.sysconf is missing on some platforms (e.g. Windows).
        sysconf = getattr(os, "sysconf", None)
        if sysconf is None:
            return None
        try:
            available_pages = sysconf("SC_AVPHYS_PAGES")
            page_size = sysconf("SC_PAGE_SIZE")
        except (ValueError, OSError):
            # The name is unknown to, or unsupported by, this platform.
            return None
        # A negative value means the limit is indeterminate.
        if available_pages < 0 or page_size <= 0:
            return None
        return available_pages * page_size

    return None


def auto_detect_slice_size(latents: torch.Tensor) -> str:
    """Choose an attention slice size setting for the given latents.

    The estimate is built from the sizes of dimensions 2 and 3 of *latents*
    and its element size, then compared against three quarters of the free
    memory on the device that holds *latents*.

    Args:
        latents: Tensor with at least four dimensions; only the sizes of
            dimensions 2 and 3, the element size and the device are read.

    Returns:
        "max" when the estimate exceeds three quarters of the free memory,
        when the free memory cannot be determined (conservative fallback),
        or when the MPS backend is available; otherwise "balanced".
    """
    # Element size plus 4 bytes, as in the original estimate; the rationale
    # for the extra 4 bytes is not documented in this file.
    bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
    spatial_elements = latents.size(dim=2) * latents.size(dim=3)
    # The factor 16 is carried over unchanged from the original heuristic.
    max_size_required_for_baddbmm = (
        16
        * spatial_elements
        * spatial_elements
        * bytes_per_element_needed_for_baddbmm_duplication
    )

    mem_free = _free_memory_bytes(latents.device)
    if mem_free is None:
        # Without a reading, prefer the setting used for the low-memory case
        # rather than risk running out of memory.
        return "max"

    if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
        return "max"
    if torch.backends.mps.is_available():
        return "max"
    return "balanced"