add support for "balanced" attention slice size

This commit is contained in:
Lincoln Stein
2023-08-17 16:11:09 -04:00
parent 23b4e1cea0
commit b69f26c85c
7 changed files with 38 additions and 12 deletions

View File

@ -12,3 +12,4 @@ from .devices import (
)
from .log import write_log
from .util import ask_user, download_with_resume, instantiate_from_config, url_attachment_name, Chdir
from .attention import auto_detect_slice_size

View File

@ -0,0 +1,24 @@
# Copyright (c) 2023 Lincoln Stein and the InvokeAI Team
"""
Utility routine for auto-detecting the optimal slice size
for the attention mechanism.
"""
import torch
def _free_memory_in_bytes(device: torch.device):
    """Return the free memory on *device* in bytes, or None if it is unknown.

    CUDA devices are queried through ``torch.cuda.mem_get_info``. CPU tensors
    use ``os.sysconf`` where the platform provides it. Any other device type,
    or a platform without the needed ``sysconf`` names, yields ``None``.
    """
    if device.type == "cuda":
        # mem_get_info returns a (free, total) pair in bytes.
        free_bytes, _total_bytes = torch.cuda.mem_get_info(device)
        return free_bytes

    if device.type == "cpu":
        # Function-scope import: the module's import block is left untouched.
        import os

        try:
            free_pages = os.sysconf("SC_AVPHYS_PAGES")
            page_size = os.sysconf("SC_PAGE_SIZE")
        except (AttributeError, ValueError, OSError):
            # os.sysconf is absent or does not know these names here.
            return None
        if free_pages < 0 or page_size <= 0:
            # sysconf reports an indeterminate value as -1.
            return None
        return free_pages * page_size

    return None


def auto_detect_slice_size(latents: torch.Tensor) -> str:
    """Choose an attention slice size setting for the given latents.

    The estimate of the memory needed is derived from the sizes of dimensions
    2 and 3 of ``latents`` and from its element size; it is compared with
    three quarters of the memory currently free on the tensor's device.

    Args:
        latents: Tensor with at least four dimensions (dimensions 2 and 3
            are read; presumably height and width -- confirm with callers).

    Returns:
        ``"max"`` when the estimate exceeds 3/4 of the free memory, when MPS
        is available, or when free memory cannot be determined; otherwise
        ``"balanced"``.
    """
    # element_size() plus 4 extra bytes per element, as in the original
    # estimate (presumably room for a float32 copy -- TODO confirm).
    bytes_per_element_needed_for_baddbmm_duplication = latents.element_size() + 4
    spatial_elements = latents.size(dim=2) * latents.size(dim=3)
    # NOTE(review): the factor 16 is carried over unchanged; its derivation
    # is not visible in this file.
    max_size_required_for_baddbmm = (
        16
        * spatial_elements
        * spatial_elements
        * bytes_per_element_needed_for_baddbmm_duplication
    )

    # With MPS available the answer is "max" regardless of the memory
    # comparison, so no memory query is needed.
    if torch.backends.mps.is_available():
        return "max"

    mem_free = _free_memory_in_bytes(latents.device)
    if mem_free is None:
        # Unknown free memory: pick the same answer used when memory is tight.
        return "max"

    if max_size_required_for_baddbmm > (mem_free * 3.0 / 4.0):
        return "max"
    return "balanced"